From 8aacea9f1a7ab9c2e7bd24505a08235fcf3b3639 Mon Sep 17 00:00:00 2001 From: aegkmq <140575296+aegkmq@users.noreply.github.com> Date: Tue, 1 Aug 2023 15:58:41 +0900 Subject: [PATCH] [fix] oobabooga prompts --- src/lib/Setting/Pages/BotSettings.svelte | 6 +- src/ts/process/request.ts | 5 +- src/ts/process/stringlize.ts | 78 ++++++++++++++++-------- 3 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/lib/Setting/Pages/BotSettings.svelte b/src/lib/Setting/Pages/BotSettings.svelte index f97d3c3b..8bdfc7f1 100644 --- a/src/lib/Setting/Pages/BotSettings.svelte +++ b/src/lib/Setting/Pages/BotSettings.svelte @@ -285,11 +285,11 @@ {#if $DataBase.ooba.formating.custom}
User Prefix - + Assistant Prefix - + Seperator - +
{/if} {:else if $DataBase.aiModel.startsWith('novelai')} diff --git a/src/ts/process/request.ts b/src/ts/process/request.ts index 933161ea..1bfc35dc 100644 --- a/src/ts/process/request.ts +++ b/src/ts/process/request.ts @@ -3,7 +3,7 @@ import type { OpenAIChat, OpenAIChatFull } from "."; import { DataBase, setDatabase, type character } from "../storage/database"; import { pluginProcess } from "../plugins/plugins"; import { language } from "../../lang"; -import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, unstringlizeAIN, unstringlizeChat } from "./stringlize"; +import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, getStopStrings, unstringlizeAIN, unstringlizeChat } from "./stringlize"; import { globalFetch, isNodeServer, isTauri } from "../storage/globalApi"; import { sleep } from "../util"; import { createDeep } from "./deepai"; @@ -375,11 +375,12 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model' let DURL = db.textgenWebUIURL let bodyTemplate:any const proompt = stringlizeChatOba(formated, currentChar?.name ?? '') + const stopStrings = getStopStrings() if(!DURL.endsWith('generate')){ DURL = DURL + "/v1/generate" } - const stopStrings = [`\nUser:`,`\nuser:`,`\n${db.username}:`] console.log(proompt) + console.log(stopStrings) bodyTemplate = { 'max_new_tokens': db.maxResponse, 'do_sample': true, diff --git a/src/ts/process/stringlize.ts b/src/ts/process/stringlize.ts index 546924b6..2f8cb72b 100644 --- a/src/ts/process/stringlize.ts +++ b/src/ts/process/stringlize.ts @@ -23,43 +23,69 @@ export function stringlizeChat(formated:OpenAIChat[], char:string = ''){ return resultString.join('\n\n') + `\n\n${char}:` } +function appendWhitespace(prefix:string, seperator:string=" ") { + if(!"> \n".includes(prefix[prefix.length-1])){ + prefix += seperator.includes("\n\n") ? "\n" : " " + } + return prefix +} export function stringlizeChatOba(formated:OpenAIChat[], char:string = ''){ const db = get(DataBase) let resultString:string[] = [] - if(db.ooba.formating.custom){ - for(const form of formated){ - if(form.role === 'system'){ - resultString.push(form.content) - } - else if(form.name){ - resultString.push(db.ooba.formating.userPrefix + form.content + db.ooba.formating.seperator) - } - else if(form.role === 'assistant' && char){ - resultString.push(db.ooba.formating.assistantPrefix + form.content + db.ooba.formating.seperator) - - } - else{ - resultString.push(form.content) - } - } - return resultString.join('\n\n') + `\n\n${db.ooba.formating.assistantPrefix}:` + let { custom, userPrefix, assistantPrefix, seperator } = db.ooba.formating; + if(!custom || !seperator){ + seperator = "\n\n" } + for(const form of formated){ - if(form.role === 'system'){ - resultString.push(form.content) + if(form.content === "[Start a new chat]"){ + continue } - else if(form.name){ - resultString.push(form.name + ": " + form.content) + let prefix = "" + if(form.role !== 'system' && form.name){ + prefix = custom ? appendWhitespace(userPrefix, seperator) : form.name + ": " } else if(form.role === 'assistant' && char){ - resultString.push(char + ": " + form.content) - + prefix = custom ? appendWhitespace(assistantPrefix, seperator) : char + ": " } - else{ - resultString.push(form.content) + resultString.push(prefix + form.content) + } + const name = custom ? assistantPrefix : char + ":" + resultString.push(name) + return resultString.join(seperator) +} + +const userStrings = ["user", "human", "input", "inst", "instruction"] +function toTitleCase(s:string){ + return s[0].toUpperCase() + s.slice(1).toLowerCase() +} +export function getStopStrings(){ + const db = get(DataBase) + let { custom, userPrefix, seperator } = db.ooba.formating; + if(!custom || !seperator){ + seperator = "\n" + } + const { username } = db + const stopStrings = [ + "GPT4 User", + userPrefix, + `${username}:`, + ] + if(seperator !== " "){ + stopStrings.push(seperator + username) + } + for (const user of userStrings){ + for (const u of [ + user.toLowerCase(), + user.toUpperCase(), + user.replace(/\w\S*/g, toTitleCase), + ]){ + stopStrings.push(`${u}:`) + stopStrings.push(`<<${u}>>`) + stopStrings.push(`### ${u}`) } } - return resultString.join('\n\n') + `\n\n${char}:` + return [...new Set(stopStrings)] } export function unstringlizeChat(text:string, formated:OpenAIChat[], char:string = ''){