From 9974ff0df10f873521ca1e1a34232035389652ac Mon Sep 17 00:00:00 2001 From: kwaroran Date: Wed, 7 Jun 2023 09:23:29 +0900 Subject: [PATCH] [refactor] transformer tokenizer --- src/ts/tokenizer.ts | 2 +- src/ts/transformers/transformer.ts | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/ts/tokenizer.ts b/src/ts/tokenizer.ts index 32797555..032a9460 100644 --- a/src/ts/tokenizer.ts +++ b/src/ts/tokenizer.ts @@ -6,7 +6,7 @@ import { tokenizeTransformers } from "./transformers/transformer"; async function encode(data:string):Promise<(number[]|Uint32Array)>{ let db = get(DataBase) if(db.aiModel === 'novellist'){ - return await tokenizeTransformers('trin',data) + return await tokenizeTransformers('naclbit/trin_tokenizer_v3',data) } return await tikJS(data) } diff --git a/src/ts/transformers/transformer.ts b/src/ts/transformers/transformer.ts index b95e40b7..893f2666 100644 --- a/src/ts/transformers/transformer.ts +++ b/src/ts/transformers/transformer.ts @@ -3,13 +3,7 @@ type transformerLibType = typeof import("@xenova/transformers"); let tokenizer:PreTrainedTokenizer = null let transformerLib:transformerLibType -const tokenizerDict = { - 'trin': 'naclbit/trin_tokenizer_v3', -} as const - -type tokenizerTypes = keyof(typeof tokenizerDict) - -let tokenizerType:tokenizerTypes|'' = '' +let tokenizerType:string = '' async function loadTransformers() { @@ -18,11 +12,11 @@ async function loadTransformers() { } } -export async function tokenizeTransformers(type:tokenizerTypes, text:string) { +export async function tokenizeTransformers(type:string, text:string) { await loadTransformers() if(tokenizerType !== type){ const AutoTokenizer = transformerLib.AutoTokenizer - tokenizer = await AutoTokenizer.from_pretrained(tokenizerDict[type]) + tokenizer = await AutoTokenizer.from_pretrained(type) tokenizerType = type }