Refactor runEmbedding function in transformers.ts
This commit is contained in:
@@ -57,39 +57,6 @@ export const runEmbedding = async (text: string, model:EmbeddingModel = 'Xenova/
|
|||||||
extractor = await pipeline('feature-extraction', model);
|
extractor = await pipeline('feature-extraction', model);
|
||||||
}
|
}
|
||||||
const tokenizer = await AutoTokenizer.from_pretrained(model);
|
const tokenizer = await AutoTokenizer.from_pretrained(model);
|
||||||
const tokens = tokenizer.encode(text)
|
|
||||||
if (tokens.length > 1024) {
|
|
||||||
let chunks:string[] = []
|
|
||||||
let chunk:number[] = []
|
|
||||||
for (let i = 0; i < tokens.length; i++) {
|
|
||||||
if (chunk.length > 256) {
|
|
||||||
chunks.push(tokenizer.decode(chunk))
|
|
||||||
chunk = []
|
|
||||||
}
|
|
||||||
chunk.push(tokens[i])
|
|
||||||
}
|
|
||||||
chunks.push(tokenizer.decode(chunk))
|
|
||||||
let results:Float32Array[] = []
|
|
||||||
for (let i = 0; i < chunks.length; i++) {
|
|
||||||
let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
|
|
||||||
const res:Float32Array = result?.data as Float32Array
|
|
||||||
|
|
||||||
if(res){
|
|
||||||
results.push(res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//set result, as average of all chunks
|
|
||||||
let result:Float32Array = new Float32Array(results[0].length)
|
|
||||||
for (let i = 0; i < results.length; i++) {
|
|
||||||
for (let j = 0; j < result.length; j++) {
|
|
||||||
result[j] += results[i][j]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (let i = 0; i < result.length; i++) {
|
|
||||||
result[i] = Math.round(result[i] / results.length)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
let result = await extractor(text, { pooling: 'mean', normalize: true });
|
let result = await extractor(text, { pooling: 'mean', normalize: true });
|
||||||
return (result?.data as Float32Array) ?? null;
|
return (result?.data as Float32Array) ?? null;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user