From 34b4a1245b2c6f1e675ec56fc7b895647465b399 Mon Sep 17 00:00:00 2001
From: kwaroran
Date: Sat, 7 Dec 2024 03:49:56 +0900
Subject: [PATCH] Add google cloud tokenizer

---
Note: hand-edited after export. tokenizeGoogleCloud now reads the cache with
the same key it writes, falls back to the local tokenizer when the request
throws, and carries comments. Hunk counts below are updated accordingly; the
post-image blob id on the index line is still the one from the original export.

 src/ts/tokenizer.ts | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 58 insertions(+), 2 deletions(-)

diff --git a/src/ts/tokenizer.ts b/src/ts/tokenizer.ts
index d767481d..ba49e19c 100644
--- a/src/ts/tokenizer.ts
+++ b/src/ts/tokenizer.ts
@@ -6,6 +6,7 @@ import { supportsInlayImage } from "./process/files/image";
 import { risuChatParser } from "./parser.svelte";
 import { tokenizeGGUFModel } from "./process/models/local";
 import { globalFetch } from "./globalApi.svelte";
+import { getModelInfo } from "./model/modellist";
 
 
 export const tokenizerList = [
@@ -80,7 +81,10 @@ export async function encode(data:string):Promise<(number[]|Uint32Array|Int32Arr
         return await tikJS(data, 'o200k_base')
     }
     if(db.aiModel.startsWith('gemini')){
-        return await tokenizeWebTokenizers(data, 'gemma')
+        if(db.aiModel.endsWith('-vertex')){
+            return await tokenizeWebTokenizers(data, 'gemma')
+        }
+        return await tokenizeGoogleCloud(data)
     }
     if(db.aiModel.startsWith('cohere')){
         return await tokenizeWebTokenizers(data, 'cohere')
@@ -89,13 +93,65 @@ export async function encode(data:string):Promise<(number[]|Uint32Array|Int32Arr
     return await tikJS(data)
 }
 
-type tokenizerType = 'novellist'|'claude'|'novelai'|'llama'|'mistral'|'llama3'|'gemma'|'cohere'
+type tokenizerType = 'novellist'|'claude'|'novelai'|'llama'|'mistral'|'llama3'|'gemma'|'cohere'|'googleCloud'
 
 let tikParser:Tiktoken = null
 let tokenizersTokenizer:Tokenizer = null
 let tokenizersType:tokenizerType = null
 let lastTikModel = 'cl100k_base'
 
+// Token counts returned by countTokens, keyed by prompt text + model internalID.
+let googleCloudTokenizedCache = new Map<string, number>()
+
+/**
+ * Counts tokens by POSTing the text to the Generative Language
+ * `countTokens` endpoint for the currently selected model.
+ *
+ * Only the count is available, so the result is a zero-filled Uint32Array
+ * whose length is the token count. Successful counts are cached per
+ * text/model pair. Any failure (non-200 status, network error, unparsable
+ * body) falls back to the local 'gemma' tokenizer.
+ */
+async function tokenizeGoogleCloud(text:string) {
+    const db = getDatabase()
+    const model = getModelInfo(db.aiModel)
+    const cacheKey = text + model.internalID
+
+    // Read with the same key that set() writes below.
+    const cached = googleCloudTokenizedCache.get(cacheKey)
+    if(cached !== undefined){
+        return new Uint32Array(cached)
+    }
+
+    try {
+        const res = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${model.internalID}:countTokens?key=${db.google?.accessToken}`, {
+            method: 'POST',
+            headers: {
+                "Content-Type": "application/json",
+            },
+            body: JSON.stringify({
+                contents: [{
+                    parts:[{
+                        text: text
+                    }]
+                }]
+            }),
+        })
+
+        if(res.status === 200){
+            const json = await res.json()
+            // A missing totalTokens still yields an empty array, as before.
+            const count:number = json.totalTokens ?? 0
+            googleCloudTokenizedCache.set(cacheKey, count)
+            return new Uint32Array(count)
+        }
+    } catch {
+        // fetch() rejected or the body was not JSON; use the fallback below.
+    }
+
+    return await tokenizeWebTokenizers(text, 'gemma')
+}
+
 async function tikJS(text:string, model='cl100k_base') {
     if(!tikParser || lastTikModel !== model){
         if(model === 'cl100k_base'){