[fix] oobabooga prompts

This commit is contained in:
aegkmq
2023-08-01 15:58:41 +09:00
parent 5ef6f00a8c
commit 8aacea9f1a
3 changed files with 58 additions and 31 deletions

View File

@@ -3,7 +3,7 @@ import type { OpenAIChat, OpenAIChatFull } from ".";
import { DataBase, setDatabase, type character } from "../storage/database";
import { pluginProcess } from "../plugins/plugins";
import { language } from "../../lang";
import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, unstringlizeAIN, unstringlizeChat } from "./stringlize";
import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, getStopStrings, unstringlizeAIN, unstringlizeChat } from "./stringlize";
import { globalFetch, isNodeServer, isTauri } from "../storage/globalApi";
import { sleep } from "../util";
import { createDeep } from "./deepai";
@@ -375,11 +375,12 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
let DURL = db.textgenWebUIURL
let bodyTemplate:any
const proompt = stringlizeChatOba(formated, currentChar?.name ?? '')
const stopStrings = getStopStrings()
if(!DURL.endsWith('generate')){
DURL = DURL + "/v1/generate"
}
const stopStrings = [`\nUser:`,`\nuser:`,`\n${db.username}:`]
console.log(proompt)
console.log(stopStrings)
bodyTemplate = {
'max_new_tokens': db.maxResponse,
'do_sample': true,

View File

@@ -23,43 +23,69 @@ export function stringlizeChat(formated:OpenAIChat[], char:string = ''){
return resultString.join('\n\n') + `\n\n${char}:`
}
function appendWhitespace(prefix:string, seperator:string=" ") {
if(!"> \n".includes(prefix[prefix.length-1])){
prefix += seperator.includes("\n\n") ? "\n" : " "
}
return prefix
}
export function stringlizeChatOba(formated:OpenAIChat[], char:string = ''){
const db = get(DataBase)
let resultString:string[] = []
if(db.ooba.formating.custom){
for(const form of formated){
if(form.role === 'system'){
resultString.push(form.content)
}
else if(form.name){
resultString.push(db.ooba.formating.userPrefix + form.content + db.ooba.formating.seperator)
}
else if(form.role === 'assistant' && char){
resultString.push(db.ooba.formating.assistantPrefix + form.content + db.ooba.formating.seperator)
}
else{
resultString.push(form.content)
}
}
return resultString.join('\n\n') + `\n\n${db.ooba.formating.assistantPrefix}:`
let { custom, userPrefix, assistantPrefix, seperator } = db.ooba.formating;
if(!custom || !seperator){
seperator = "\n\n"
}
for(const form of formated){
if(form.role === 'system'){
resultString.push(form.content)
if(form.content === "[Start a new chat]"){
continue
}
else if(form.name){
resultString.push(form.name + ": " + form.content)
let prefix = ""
if(form.role !== 'system' && form.name){
prefix = custom ? appendWhitespace(userPrefix, seperator) : form.name + ": "
}
else if(form.role === 'assistant' && char){
resultString.push(char + ": " + form.content)
prefix = custom ? appendWhitespace(assistantPrefix, seperator) : char + ": "
}
else{
resultString.push(form.content)
resultString.push(prefix + form.content)
}
const name = custom ? assistantPrefix : char + ":"
resultString.push(name)
return resultString.join(seperator)
}
const userStrings = ["user", "human", "input", "inst", "instruction"]
function toTitleCase(s:string){
return s[0].toUpperCase() + s.slice(1).toLowerCase()
}
export function getStopStrings(){
const db = get(DataBase)
let { custom, userPrefix, seperator } = db.ooba.formating;
if(!custom || !seperator){
seperator = "\n"
}
const { username } = db
const stopStrings = [
"GPT4 User",
userPrefix,
`${username}:`,
]
if(seperator !== " "){
stopStrings.push(seperator + username)
}
for (const user of userStrings){
for (const u of [
user.toLowerCase(),
user.toUpperCase(),
user.replace(/\w\S*/g, toTitleCase),
]){
stopStrings.push(`${u}:`)
stopStrings.push(`<<${u}>>`)
stopStrings.push(`### ${u}`)
}
}
return resultString.join('\n\n') + `\n\n${char}:`
return [...new Set(stopStrings)]
}
export function unstringlizeChat(text:string, formated:OpenAIChat[], char:string = ''){