diff --git a/src/ts/process/embedding/transformers.ts b/src/ts/process/embedding/transformers.ts index 619fc384..b3706bb2 100644 --- a/src/ts/process/embedding/transformers.ts +++ b/src/ts/process/embedding/transformers.ts @@ -1,14 +1,7 @@ -import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers'; +import transformers, { AutoTokenizer, Pipeline, pipeline, type DataArray, type SummarizationOutput } from '@xenova/transformers'; -let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise = null +transformers.env.localModelPath = "https://sv.risuai.xyz/transformers/" -async function loadTransformer() { - if(!pipeline){ - const transformersLib = await import('@xenova/transformers') - transformersLib.env.localModelPath = "https://sv.risuai.xyz/transformers/" - pipeline = transformersLib.pipeline - } -} type TransformersBodyType = { max_new_tokens: number, do_sample: boolean, @@ -33,23 +26,21 @@ type TransformersBodyType = { export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => { - await loadTransformer() let text = baseText let generator = await pipeline('text-generation', model); - let output:{generated_text:string}[] = await generator(text); - return output + let output = await generator(text) as transformers.TextGenerationOutput + const outputOne = output[0] + return outputOne } export const runSummarizer = async (text: string) => { - await loadTransformer() let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6") - const v:{summary_text:string}[] = await classifier(text) - return v + const v = await classifier(text) as SummarizationOutput + return v[0].summary_text } -let extractor:Pipeline = null +let extractor:transformers.FeatureExtractionPipeline = null export const runEmbedding = async (text: string):Promise => { - await loadTransformer() if(!extractor){ extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'); } @@ -69,7 +60,11 @@ export const runEmbedding = async (text: string):Promise => { let results:Float32Array[] = [] for (let i = 0; i < chunks.length; i++) { let result = await extractor(chunks[i], { pooling: 'mean', normalize: true }); - results.push(result?.data ?? null) + const res:Float32Array = result?.data as Float32Array + + if(res){ + results.push(res) + } } //set result, as average of all chunks let result:Float32Array = new Float32Array(results[0].length) @@ -84,7 +79,7 @@ export const runEmbedding = async (text: string):Promise => { return result } let result = await extractor(text, { pooling: 'mean', normalize: true }); - return result?.data ?? null; + return (result?.data as Float32Array) ?? null; } export const runTTS = async (text: string) => {