[feat] accurate tokenizing

This commit is contained in:
kwaroran
2023-11-23 17:52:00 +09:00
parent dab121c9c7
commit 634fe418b4
6 changed files with 116 additions and 16 deletions

View File

@@ -5,7 +5,7 @@
import { DataBase } from "src/ts/storage/database"; import { DataBase } from "src/ts/storage/database";
import { customProviderStore, getCurrentPluginMax } from "src/ts/plugins/plugins"; import { customProviderStore, getCurrentPluginMax } from "src/ts/plugins/plugins";
import { getModelMaxContext, isTauri } from "src/ts/storage/globalApi"; import { getModelMaxContext, isTauri } from "src/ts/storage/globalApi";
import { tokenize } from "src/ts/tokenizer"; import { tokenize, tokenizeAccurate } from "src/ts/tokenizer";
import ModelList from "src/lib/UI/ModelList.svelte"; import ModelList from "src/lib/UI/ModelList.svelte";
import DropList from "src/lib/SideBars/DropList.svelte"; import DropList from "src/lib/SideBars/DropList.svelte";
import { PlusIcon, TrashIcon } from "lucide-svelte"; import { PlusIcon, TrashIcon } from "lucide-svelte";
@@ -40,10 +40,10 @@
export let goPromptTemplate = () => {} export let goPromptTemplate = () => {}
async function loadTokenize(){ async function loadTokenize(){
tokens.mainPrompt = await tokenize($DataBase.mainPrompt) tokens.mainPrompt = await tokenizeAccurate($DataBase.mainPrompt)
tokens.jailbreak = await tokenize($DataBase.jailbreak) tokens.jailbreak = await tokenizeAccurate($DataBase.jailbreak)
tokens.globalNote = await tokenize($DataBase.globalNote) tokens.globalNote = await tokenizeAccurate($DataBase.globalNote)
tokens.autoSuggest = await tokenize($DataBase.autoSuggestPrompt) tokens.autoSuggest = await tokenizeAccurate($DataBase.autoSuggestPrompt)
} }
let advancedBotSettings = false let advancedBotSettings = false

View File

@@ -2,7 +2,7 @@
import { ArrowLeft, PlusIcon } from "lucide-svelte"; import { ArrowLeft, PlusIcon } from "lucide-svelte";
import { language } from "src/lang"; import { language } from "src/lang";
import ProomptItem from "src/lib/UI/ProomptItem.svelte"; import ProomptItem from "src/lib/UI/ProomptItem.svelte";
import type { Proompt } from "src/ts/process/proompt"; import { tokenizePreset, type Proompt } from "src/ts/process/proompt";
import { templateCheck } from "src/ts/process/templates/templateCheck"; import { templateCheck } from "src/ts/process/templates/templateCheck";
import { DataBase } from "src/ts/storage/database"; import { DataBase } from "src/ts/storage/database";
@@ -10,8 +10,15 @@
let opened = 0 let opened = 0
let warns: string[] = [] let warns: string[] = []
export let onGoBack: () => void = () => {} export let onGoBack: () => void = () => {}
let tokens = 0
executeTokenize($DataBase.promptTemplate)
async function executeTokenize(prest: Proompt[]){
tokens = await tokenizePreset(prest)
}
$: warns = templateCheck($DataBase) $: warns = templateCheck($DataBase)
$: executeTokenize($DataBase.promptTemplate)
</script> </script>
<h2 class="mb-2 text-2xl font-bold mt-2 items-center flex"> <h2 class="mb-2 text-2xl font-bold mt-2 items-center flex">
@@ -71,4 +78,6 @@
type2: 'normal' type2: 'normal'
}) })
$DataBase.promptTemplate = value $DataBase.promptTemplate = value
}}><PlusIcon /></button> }}><PlusIcon /></button>
<span class="text-textcolor2 mb-6 text-sm mt-2">{tokens} {language.tokens}</span>

View File

@@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import { language } from "../../lang"; import { language } from "../../lang";
import { tokenize } from "../../ts/tokenizer"; import { tokenize, tokenizeAccurate } from "../../ts/tokenizer";
import { DataBase, saveImage as saveAsset, type Database, type character, type groupChat } from "../../ts/storage/database"; import { DataBase, saveImage as saveAsset, type Database, type character, type groupChat } from "../../ts/storage/database";
import { selectedCharID } from "../../ts/stores"; import { selectedCharID } from "../../ts/stores";
import { PlusIcon, SmileIcon, TrashIcon, UserIcon, ActivityIcon, BookIcon, LoaderIcon, User, DnaIcon, CurlyBraces, Volume2Icon, XIcon } from 'lucide-svelte' import { PlusIcon, SmileIcon, TrashIcon, UserIcon, ActivityIcon, BookIcon, LoaderIcon, User, DnaIcon, CurlyBraces, Volume2Icon, XIcon } from 'lucide-svelte'
@@ -51,17 +51,17 @@
if(lasttokens.desc !== cha.desc){ if(lasttokens.desc !== cha.desc){
if(cha.desc){ if(cha.desc){
lasttokens.desc = cha.desc lasttokens.desc = cha.desc
tokens.desc = await tokenize(cha.desc) tokens.desc = await tokenizeAccurate(cha.desc)
} }
} }
if(lasttokens.firstMsg !==chara.firstMessage){ if(lasttokens.firstMsg !==chara.firstMessage){
lasttokens.firstMsg = chara.firstMessage lasttokens.firstMsg = chara.firstMessage
tokens.firstMsg = await tokenize(chara.firstMessage) tokens.firstMsg = await tokenizeAccurate(chara.firstMessage)
} }
} }
if(lasttokens.localNote !== currentChar.data.chats[currentChar.data.chatPage].note){ if(lasttokens.localNote !== currentChar.data.chats[currentChar.data.chatPage].note){
lasttokens.localNote = currentChar.data.chats[currentChar.data.chatPage].note lasttokens.localNote = currentChar.data.chats[currentChar.data.chatPage].note
tokens.localNote = await tokenize(currentChar.data.chats[currentChar.data.chatPage].note) tokens.localNote = await tokenizeAccurate(currentChar.data.chats[currentChar.data.chatPage].note)
} }

View File

@@ -11,6 +11,7 @@ import { selectedCharID } from './stores';
import { calcString } from './process/infunctions'; import { calcString } from './process/infunctions';
import { findCharacterbyId } from './util'; import { findCharacterbyId } from './util';
import { getInlayImage } from './image'; import { getInlayImage } from './image';
import { cloneDeep } from 'lodash';
const convertora = new showdown.Converter({ const convertora = new showdown.Converter({
simpleLineBreaks: true, simpleLineBreaks: true,
@@ -346,7 +347,14 @@ function wppParser(data:string){
const rgx = /(?:{{|<)(.+?)(?:}}|>)/gm const rgx = /(?:{{|<)(.+?)(?:}}|>)/gm
type matcherArg = {chatID:number,db:Database,chara:character|string,rmVar:boolean,var?:{[key:string]:string}} type matcherArg = {
chatID:number,
db:Database,
chara:character|string,
rmVar:boolean,
var?:{[key:string]:string}
tokenizeAccurate?:boolean
}
const matcher = (p1:string,matcherArg:matcherArg) => { const matcher = (p1:string,matcherArg:matcherArg) => {
if(p1.length > 10000){ if(p1.length > 10000){
return '' return ''
@@ -390,7 +398,7 @@ const matcher = (p1:string,matcherArg:matcherArg) => {
case 'bot':{ case 'bot':{
let selectedChar = get(selectedCharID) let selectedChar = get(selectedCharID)
let currentChar = db.characters[selectedChar] let currentChar = db.characters[selectedChar]
if(currentChar.type !== 'group'){ if(currentChar && currentChar.type !== 'group'){
return currentChar.name return currentChar.name
} }
if(chara){ if(chara){
@@ -484,6 +492,9 @@ const matcher = (p1:string,matcherArg:matcherArg) => {
return '' return ''
} }
case 'time':{ case 'time':{
if(matcherArg.tokenizeAccurate){
return `00:00:00`
}
if(chatID === -1){ if(chatID === -1){
return "[Cannot get time]" return "[Cannot get time]"
} }
@@ -499,6 +510,9 @@ const matcher = (p1:string,matcherArg:matcherArg) => {
return date.toLocaleTimeString() return date.toLocaleTimeString()
} }
case 'date':{ case 'date':{
if(matcherArg.tokenizeAccurate){
return `00:00:00`
}
if(chatID === -1){ if(chatID === -1){
return "[Cannot get time]" return "[Cannot get time]"
} }
@@ -513,6 +527,9 @@ const matcher = (p1:string,matcherArg:matcherArg) => {
return date.toLocaleDateString() return date.toLocaleDateString()
} }
case 'idle_duration':{ case 'idle_duration':{
if(matcherArg.tokenizeAccurate){
return `00:00:00`
}
if(chatID === -1){ if(chatID === -1){
return "[Cannot get time]" return "[Cannot get time]"
} }
@@ -622,11 +639,17 @@ const matcher = (p1:string,matcherArg:matcherArg) => {
if(p1.startsWith('random')){ if(p1.startsWith('random')){
if(p1.startsWith('random::')){ if(p1.startsWith('random::')){
const randomIndex = Math.floor(Math.random() * (arra.length - 1)) + 1 const randomIndex = Math.floor(Math.random() * (arra.length - 1)) + 1
if(matcherArg.tokenizeAccurate){
return arra[0]
}
return arra[randomIndex] return arra[randomIndex]
} }
else{ else{
const arr = p1.split(/\:|\,/g) const arr = p1.split(/\:|\,/g)
const randomIndex = Math.floor(Math.random() * (arr.length - 1)) + 1 const randomIndex = Math.floor(Math.random() * (arr.length - 1)) + 1
if(matcherArg.tokenizeAccurate){
return arra[0]
}
return arr[randomIndex] return arr[randomIndex]
} }
} }
@@ -656,7 +679,7 @@ const smMatcher = (p1:string,matcherArg:matcherArg) => {
case 'bot':{ case 'bot':{
let selectedChar = get(selectedCharID) let selectedChar = get(selectedCharID)
let currentChar = db.characters[selectedChar] let currentChar = db.characters[selectedChar]
if(currentChar.type !== 'group'){ if(currentChar && currentChar.type !== 'group'){
return currentChar.name return currentChar.name
} }
if(chara){ if(chara){
@@ -707,6 +730,7 @@ export function risuChatParser(da:string, arg:{
chara?:string|character|groupChat chara?:string|character|groupChat
rmVar?:boolean, rmVar?:boolean,
var?:{[key:string]:string} var?:{[key:string]:string}
tokenizeAccurate?:boolean
} = {}):string{ } = {}):string{
const chatID = arg.chatID ?? -1 const chatID = arg.chatID ?? -1
const db = arg.db ?? get(DataBase) const db = arg.db ?? get(DataBase)
@@ -724,6 +748,13 @@ export function risuChatParser(da:string, arg:{
chara = aChara chara = aChara
} }
} }
if(arg.tokenizeAccurate){
const db = arg.db ?? get(DataBase)
const selchar = chara ?? db.characters[get(selectedCharID)]
if(!selchar){
chara = 'bot'
}
}
let pointer = 0; let pointer = 0;
@@ -731,12 +762,16 @@ export function risuChatParser(da:string, arg:{
let pf = performance.now() let pf = performance.now()
let v = new Uint8Array(512) let v = new Uint8Array(512)
let pureMode = false let pureMode = false
let commentMode = false
let commentLatest:string[] = [""]
let commentV = new Uint8Array(512)
const matcherObj = { const matcherObj = {
chatID: chatID, chatID: chatID,
chara: chara, chara: chara,
rmVar: arg.rmVar ?? false, rmVar: arg.rmVar ?? false,
db: db, db: db,
var: arg.var ?? null var: arg.var ?? null,
tokenizeAccurate: arg.tokenizeAccurate ?? false
} }
while(pointer < da.length){ while(pointer < da.length){
switch(da[pointer]){ switch(da[pointer]){
@@ -791,6 +826,25 @@ export function risuChatParser(da:string, arg:{
pureMode = false pureMode = false
break break
} }
case 'Comment':{
if(!commentMode){
commentMode = true
commentLatest = nested.map((f) => f)
if(commentLatest[0].endsWith('\n')){
commentLatest[0] = commentLatest[0].substring(0, commentLatest[0].length - 1)
}
commentV = new Uint8Array(v)
}
break
}
case '/Comment':{
if(commentMode){
nested = commentLatest
v = commentV
commentMode = false
}
break
}
default:{ default:{
const mc = (pureMode) ? null : smMatcher(dat, matcherObj) const mc = (pureMode) ? null : smMatcher(dat, matcherObj)
nested[0] += mc ?? `<${dat}>` nested[0] += mc ?? `<${dat}>`

View File

@@ -1,3 +1,5 @@
import { tokenizeAccurate } from "../tokenizer";
export type Proompt = ProomptPlain|ProomptTyped|ProomptChat|ProomptAuthorNote; export type Proompt = ProomptPlain|ProomptTyped|ProomptChat|ProomptAuthorNote;
export interface ProomptPlain { export interface ProomptPlain {
@@ -23,4 +25,29 @@ export interface ProomptChat {
type: 'chat'; type: 'chat';
rangeStart: number; rangeStart: number;
rangeEnd: number|'end'; rangeEnd: number|'end';
}
export async function tokenizePreset(proompts:Proompt[]){
let total = 0
for(const proompt of proompts){
switch(proompt.type){
case 'plain':
case 'jailbreak':{
total += await tokenizeAccurate(proompt.text)
break
}
case 'persona':
case 'description':
case 'lorebook':
case 'postEverything':
case 'authornote':
case 'memory':{
if(proompt.innerFormat){
total += await tokenizeAccurate(proompt.innerFormat)
}
break
}
}
}
return total
} }

View File

@@ -4,6 +4,8 @@ import { DataBase, type character } from "./storage/database";
import { get } from "svelte/store"; import { get } from "svelte/store";
import type { OpenAIChat } from "./process"; import type { OpenAIChat } from "./process";
import { supportsInlayImage } from "./image"; import { supportsInlayImage } from "./image";
import { risuChatParser } from "./parser";
import type { Proompt } from "./process/proompt";
async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{ async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
let db = get(DataBase) let db = get(DataBase)
@@ -85,6 +87,14 @@ export async function tokenize(data:string) {
return encoded.length return encoded.length
} }
export async function tokenizeAccurate(data:string) {
data = risuChatParser(data.replace('{{slot}}',''), {
tokenizeAccurate: true
})
const encoded = await encode(data)
return encoded.length
}
export class ChatTokenizer { export class ChatTokenizer {
@@ -149,4 +159,4 @@ export class ChatTokenizer {
export async function tokenizeNum(data:string) { export async function tokenizeNum(data:string) {
const encoded = await encode(data) const encoded = await encode(data)
return encoded return encoded
} }