Refactor Transformers
This commit is contained in:
@@ -1,14 +1,7 @@
|
||||
import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers';
|
||||
import transformers, { AutoTokenizer, Pipeline, pipeline, type DataArray, type SummarizationOutput } from '@xenova/transformers';
|
||||
|
||||
let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise<Pipeline> = null
|
||||
transformers.env.localModelPath = "https://sv.risuai.xyz/transformers/"
|
||||
|
||||
async function loadTransformer() {
|
||||
if(!pipeline){
|
||||
const transformersLib = await import('@xenova/transformers')
|
||||
transformersLib.env.localModelPath = "https://sv.risuai.xyz/transformers/"
|
||||
pipeline = transformersLib.pipeline
|
||||
}
|
||||
}
|
||||
type TransformersBodyType = {
|
||||
max_new_tokens: number,
|
||||
do_sample: boolean,
|
||||
@@ -33,23 +26,21 @@ type TransformersBodyType = {
|
||||
|
||||
|
||||
export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => {
|
||||
await loadTransformer()
|
||||
let text = baseText
|
||||
let generator = await pipeline('text-generation', model);
|
||||
let output:{generated_text:string}[] = await generator(text);
|
||||
return output
|
||||
let output = await generator(text) as transformers.TextGenerationOutput
|
||||
const outputOne = output[0]
|
||||
return outputOne
|
||||
}
|
||||
|
||||
export const runSummarizer = async (text: string) => {
|
||||
await loadTransformer()
|
||||
let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
|
||||
const v:{summary_text:string}[] = await classifier(text)
|
||||
return v
|
||||
const v = await classifier(text) as SummarizationOutput
|
||||
return v[0].summary_text
|
||||
}
|
||||
|
||||
let extractor:Pipeline = null
|
||||
let extractor:transformers.FeatureExtractionPipeline = null
|
||||
export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
||||
await loadTransformer()
|
||||
if(!extractor){
|
||||
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
|
||||
}
|
||||
@@ -69,7 +60,11 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
||||
let results:Float32Array[] = []
|
||||
for (let i = 0; i < chunks.length; i++) {
|
||||
let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
|
||||
results.push(result?.data ?? null)
|
||||
const res:Float32Array = result?.data as Float32Array
|
||||
|
||||
if(res){
|
||||
results.push(res)
|
||||
}
|
||||
}
|
||||
//set result, as average of all chunks
|
||||
let result:Float32Array = new Float32Array(results[0].length)
|
||||
@@ -84,7 +79,7 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
||||
return result
|
||||
}
|
||||
let result = await extractor(text, { pooling: 'mean', normalize: true });
|
||||
return result?.data ?? null;
|
||||
return (result?.data as Float32Array) ?? null;
|
||||
}
|
||||
|
||||
export const runTTS = async (text: string) => {
|
||||
|
||||
Reference in New Issue
Block a user