From dcebac658fa6653edca9a9f0ccf7ccdfde61663b Mon Sep 17 00:00:00 2001
From: kwaroran
Date: Fri, 5 Jan 2024 23:29:57 +0900
Subject: [PATCH] [fix] embedding over 512 tokens

---
 src/ts/process/embedding/transformers.ts | 32 +++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/src/ts/process/embedding/transformers.ts b/src/ts/process/embedding/transformers.ts
index cac64bda..716f67fd 100644
--- a/src/ts/process/embedding/transformers.ts
+++ b/src/ts/process/embedding/transformers.ts
@@ -1,4 +1,4 @@
-import type { Pipeline, PretrainedOptions } from '@xenova/transformers';
+import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers';
 
 let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise<Pipeline> = null
 
@@ -50,6 +50,36 @@ export const runSummarizer = async (text: string) => {
 export const runEmbedding = async (text: string):Promise<Float32Array> => {
     await loadTransformer()
     let extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
+    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');
+    const tokens = tokenizer.encode(text)
+    if (tokens.length > 256) {
+        let chunks:string[] = []
+        let chunk:number[] = []
+        for (let i = 0; i < tokens.length; i++) {
+            if (chunk.length > 256) {
+                chunks.push(tokenizer.decode(chunk))
+                chunk = []
+            }
+            chunk.push(tokens[i])
+        }
+        chunks.push(tokenizer.decode(chunk))
+        let results:Float32Array[] = []
+        for (let i = 0; i < chunks.length; i++) {
+            let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
+            results.push(result?.data ?? null)
+        }
+        //set result, as average of all chunks
+        let result:Float32Array = new Float32Array(results[0].length)
+        for (let i = 0; i < results.length; i++) {
+            for (let j = 0; j < result.length; j++) {
+                result[j] += results[i][j]
+            }
+        }
+        for (let i = 0; i < result.length; i++) {
+            result[i] = Math.round(result[i] / results.length)
+        }
+        return result
+    }
     let result = await extractor(text, { pooling: 'mean', normalize: true });
     return result?.data ?? null;
 }
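
Note on the approach: when the input tokenizes to more than 256 tokens, the patch splits the token sequence into chunks, embeds each chunk separately with mean pooling, and combines the per-chunk vectors element-wise. The sketch below is a minimal, self-contained illustration of that chunk-and-average step; the constant and helper names are hypothetical (they do not exist in the RisuAI codebase), and it keeps plain floating-point averages, whereas the patch additionally rounds each averaged component.

// Illustrative sketch only (hypothetical helpers, not part of the patch).
// MAX_CHUNK_TOKENS mirrors the 256-token threshold used in runEmbedding above.
const MAX_CHUNK_TOKENS = 256

// Split a token-id sequence into consecutive chunks of at most MAX_CHUNK_TOKENS tokens.
function chunkTokens(tokens: number[]): number[][] {
    const chunks: number[][] = []
    for (let i = 0; i < tokens.length; i += MAX_CHUNK_TOKENS) {
        chunks.push(tokens.slice(i, i + MAX_CHUNK_TOKENS))
    }
    return chunks
}

// Element-wise average of several equal-length embedding vectors.
function averageEmbeddings(embeddings: Float32Array[]): Float32Array {
    const out = new Float32Array(embeddings[0].length)
    for (const emb of embeddings) {
        for (let j = 0; j < out.length; j++) {
            out[j] += emb[j] / embeddings.length
        }
    }
    return out
}

// Example: averaging [1, 3] and [3, 5] yields [2, 4].
// averageEmbeddings([Float32Array.from([1, 3]), Float32Array.from([3, 5])])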