diff --git a/public/token/gemma/tokenizer.model b/public/token/gemma/tokenizer.model new file mode 100644 index 00000000..f83987c4 Binary files /dev/null and b/public/token/gemma/tokenizer.model differ diff --git a/src/ts/tokenizer.ts b/src/ts/tokenizer.ts index ca38b90d..77c980c8 100644 --- a/src/ts/tokenizer.ts +++ b/src/ts/tokenizer.ts @@ -6,6 +6,7 @@ import type { MultiModal, OpenAIChat } from "./process"; import { supportsInlayImage } from "./process/files/image"; import { risuChatParser } from "./parser"; import { tokenizeGGUFModel } from "./process/models/local"; +import { globalFetch } from "./storage/globalApi"; export const tokenizerList = [ @@ -78,11 +79,14 @@ export async function encode(data:string):Promise<(number[]|Uint32Array|Int32Arr if(db.aiModel.startsWith('gpt4o')){ return await tikJS(data, 'o200k_base') } + if(db.aiModel.startsWith('gemini')){ + return await tokenizeWebTokenizers(data, 'gemma') + } return await tikJS(data) } -type tokenizerType = 'novellist'|'claude'|'novelai'|'llama'|'mistral'|'llama3' +type tokenizerType = 'novellist'|'claude'|'novelai'|'llama'|'mistral'|'llama3'|'gemma' let tikParser:Tiktoken = null let tokenizersTokenizer:Tokenizer = null @@ -116,6 +120,31 @@ async function tikJS(text:string, model='cl100k_base') { return tikParser.encode(text) } +async function geminiTokenizer(text:string) { + const db = get(DataBase) + const fetchResult = await globalFetch(`https://generativelanguage.googleapis.com/v1beta/${db.aiModel}:countTextTokens`, { + "headers": { + "content-type": "application/json", + "authorization": `Bearer ${db.google.accessToken}` + }, + "body": JSON.stringify({ + "prompt":{ + text: text + } + }), + "method": "POST" + }) + + if(!fetchResult.ok){ + //fallback to tiktoken + return await tikJS(text) + } + + const result = fetchResult.data + + return result.tokenCount ?? 0 +} + async function tokenizeWebTokenizers(text:string, type:tokenizerType) { if(type !== tokenizersType || !tokenizersTokenizer){ const webTokenizer = await import('@mlc-ai/web-tokenizers') @@ -151,6 +180,11 @@ async function tokenizeWebTokenizers(text:string, type:tokenizerType) { await (await fetch("/token/mistral/tokenizer.model") ).arrayBuffer()) break + case 'gemma': + tokenizersTokenizer = await webTokenizer.Tokenizer.fromSentencePiece( + await (await fetch("/token/gemma/tokenizer.model") + ).arrayBuffer()) + break } tokenizersType = type