[fix] Support embedding text longer than 512 tokens
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
import type { Pipeline, PretrainedOptions } from '@xenova/transformers';
|
import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers';
|
||||||
|
|
||||||
/**
 * Call signature stored in `pipeline`: takes a task name, an optional model id
 * and optional `PretrainedOptions`, and resolves to a `Pipeline`.
 */
type TransformerPipelineFactory = (
    task: string,
    model?: string,
    options?: PretrainedOptions,
) => Promise<Pipeline>

/**
 * Module-level handle to the pipeline factory. It starts out as `null`.
 * NOTE(review): presumably assigned by `loadTransformer()` before first use — confirm.
 */
let pipeline: TransformerPipelineFactory = null
|
||||||
|
|
||||||
@@ -50,6 +50,36 @@ export const runSummarizer = async (text: string) => {
|
|||||||
/**
 * Splits a sequence of token ids into consecutive slices holding at most
 * `size` ids each. The last slice may be shorter.
 *
 * @param tokenIds Token ids in their original order.
 * @param size Maximum number of ids per slice; must be positive.
 * @returns The slices, in order. Empty when `tokenIds` is empty.
 */
function splitTokenIds(tokenIds: number[], size: number): number[][] {
    const slices: number[][] = []
    for (let start = 0; start < tokenIds.length; start += size) {
        slices.push(tokenIds.slice(start, start + size))
    }
    return slices
}

/**
 * Computes the component-wise mean of equally sized vectors and rescales it
 * to unit length, so the result is comparable to a single embedding that was
 * requested with `normalize: true`.
 *
 * @param vectors One or more vectors of identical length.
 * @returns A new vector; the inputs are not modified.
 */
function meanUnitVector(vectors: Float32Array[]): Float32Array {
    const mean = new Float32Array(vectors[0].length)
    for (const vector of vectors) {
        for (let j = 0; j < mean.length; j++) {
            mean[j] += vector[j]
        }
    }

    let squaredNorm = 0
    for (let j = 0; j < mean.length; j++) {
        // Keep the fractional value: rounding here would turn almost every
        // component of a normalized embedding into 0.
        mean[j] /= vectors.length
        squaredNorm += mean[j] * mean[j]
    }

    const norm = Math.sqrt(squaredNorm)
    if (norm > 0) {
        for (let j = 0; j < mean.length; j++) {
            mean[j] /= norm
        }
    }
    return mean
}

/**
 * Embeds `text` with the 'Xenova/all-MiniLM-L6-v2' feature-extraction pipeline
 * (mean pooling, normalized output).
 *
 * When the tokenized text is longer than 256 token ids, it is cut into chunks
 * of at most 256 ids, each chunk is decoded back to text and embedded on its
 * own, and the chunk embeddings are averaged and re-normalized.
 *
 * @param text The text to embed.
 * @returns The embedding data, or `null` when the pipeline yields no data.
 */
export const runEmbedding = async (text: string):Promise<Float32Array> => {
    const maxChunkTokens = 256

    await loadTransformer()
    const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');
    const tokens = tokenizer.encode(text)

    if (tokens.length > maxChunkTokens) {
        const chunkEmbeddings: Float32Array[] = []
        for (const chunkIds of splitTokenIds(tokens, maxChunkTokens)) {
            // Drop special tokens when decoding so their literal text is not
            // fed back into the extractor as part of the chunk.
            const chunkText = tokenizer.decode(chunkIds, { skip_special_tokens: true })
            if (chunkText.trim().length === 0) {
                // A chunk holding only special tokens carries no content and
                // must not dilute the average.
                continue
            }
            const chunkResult = await extractor(chunkText, { pooling: 'mean', normalize: true });
            if (chunkResult?.data) {
                chunkEmbeddings.push(chunkResult.data)
            }
        }
        if (chunkEmbeddings.length > 0) {
            return meanUnitVector(chunkEmbeddings)
        }
        // No chunk produced data: fall through and embed the text as a whole.
    }

    const result = await extractor(text, { pooling: 'mean', normalize: true });
    return result?.data ?? null;
}
|
||||||
|
|||||||
Reference in New Issue
Block a user