Add systemContentReplacement and Flags

This commit is contained in:
Kwaroran
2024-11-27 04:33:12 +09:00
parent cc8d753dc8
commit 597c8879fc
5 changed files with 364 additions and 100 deletions

View File

@@ -20,7 +20,7 @@ import {Ollama} from 'ollama/dist/browser.mjs'
import { applyChatTemplate } from "./templates/chatTemplate";
import { OobaParams } from "./prompt";
import { extractJSON, getOpenAIJSONSchema } from "./templates/jsonSchema";
import { getModelInfo, LLMFormat, type LLMModel } from "../model/modellist";
import { getModelInfo, LLMFlags, LLMFormat, type LLMModel } from "../model/modellist";
@@ -88,7 +88,7 @@ interface OaiFunctions {
}
type Parameter = 'temperature'|'top_k'|'repetition_penalty'|'min_p'|'top_a'|'top_p'|'frequency_penalty'|'presence_penalty'
export type Parameter = 'temperature'|'top_k'|'repetition_penalty'|'min_p'|'top_a'|'top_p'|'frequency_penalty'|'presence_penalty'
type ParameterMap = {
[key in Parameter]?: string;
};
@@ -182,6 +182,63 @@ export interface OpenAIChatExtra {
multimodals?:MultiModal[]
}
function reformater(formated:OpenAIChat[],modelInfo:LLMModel){
const db = getDatabase()
let systemPrompt:OpenAIChat|null = null
if(!modelInfo.flags.includes(LLMFlags.hasFullSystemPrompt)){
if(modelInfo.flags.includes(LLMFlags.hasFirstSystemPrompt)){
if(formated[0].role === 'system'){
systemPrompt = formated[0]
formated = formated.slice(1)
}
}
for(let i=0;i<formated.length;i++){
if(formated[i].role === 'system'){
formated[i].content = db.systemContentReplacement.replace('{{slot}}', formated[i].content)
formated[i].role = db.systemRoleReplacement
}
}
}
if(modelInfo.flags.includes(LLMFlags.requiresAlternateRole)){
let newFormated:OpenAIChat[] = []
for(let i=0;i<formated.length;i++){
const m = formated[i]
if(newFormated.length === 0){
newFormated.push(m)
continue
}
if(newFormated[newFormated.length-1].role === m.role){
newFormated[newFormated.length-1].content += '\n' + m.content
continue
}
else{
newFormated.push(m)
}
}
formated = newFormated
}
if(modelInfo.flags.includes(LLMFlags.mustStartWithUserInput)){
if(formated.length === 0 || formated[0].role !== 'user'){
formated.unshift({
role: 'user',
content: ' '
})
}
}
if(systemPrompt){
formated.unshift(systemPrompt)
}
return formated
}
export async function requestChatDataMain(arg:requestDataArgument, model:'model'|'submodel', abortSignal:AbortSignal=null):Promise<requestDataResponse> {
const db = getDatabase()
@@ -206,6 +263,8 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
const format = targ.modelInfo.format
targ.formated = reformater(targ.formated, targ.modelInfo)
switch(format){
case LLMFormat.OpenAICompatible:
case LLMFormat.Mistral:
@@ -437,14 +496,13 @@ async function requestOpenAI(arg:RequestDataArgumentExtended):Promise<requestDat
}
const res = await globalFetch(arg.customURL ?? "https://api.mistral.ai/v1/chat/completions", {
body: {
body: applyParameters({
model: requestModel,
messages: reformatedChat,
temperature: arg.temperature,
max_tokens: arg.maxTokens,
top_p: db.top_p,
safe_prompt: false
},
safe_prompt: false,
max_tokens: arg.maxTokens,
}, ['temperature', 'presence_penalty', 'frequency_penalty'] ),
headers: {
"Authorization": "Bearer " + db.mistralKey,
},