Refactor Transformers
This commit is contained in:
@@ -1,14 +1,7 @@
|
|||||||
import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers';
|
import transformers, { AutoTokenizer, Pipeline, pipeline, type DataArray, type SummarizationOutput } from '@xenova/transformers';
|
||||||
|
|
||||||
let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise<Pipeline> = null
|
transformers.env.localModelPath = "https://sv.risuai.xyz/transformers/"
|
||||||
|
|
||||||
async function loadTransformer() {
|
|
||||||
if(!pipeline){
|
|
||||||
const transformersLib = await import('@xenova/transformers')
|
|
||||||
transformersLib.env.localModelPath = "https://sv.risuai.xyz/transformers/"
|
|
||||||
pipeline = transformersLib.pipeline
|
|
||||||
}
|
|
||||||
}
|
|
||||||
type TransformersBodyType = {
|
type TransformersBodyType = {
|
||||||
max_new_tokens: number,
|
max_new_tokens: number,
|
||||||
do_sample: boolean,
|
do_sample: boolean,
|
||||||
@@ -33,23 +26,21 @@ type TransformersBodyType = {
|
|||||||
|
|
||||||
|
|
||||||
export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => {
|
export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => {
|
||||||
await loadTransformer()
|
|
||||||
let text = baseText
|
let text = baseText
|
||||||
let generator = await pipeline('text-generation', model);
|
let generator = await pipeline('text-generation', model);
|
||||||
let output:{generated_text:string}[] = await generator(text);
|
let output = await generator(text) as transformers.TextGenerationOutput
|
||||||
return output
|
const outputOne = output[0]
|
||||||
|
return outputOne
|
||||||
}
|
}
|
||||||
|
|
||||||
export const runSummarizer = async (text: string) => {
|
export const runSummarizer = async (text: string) => {
|
||||||
await loadTransformer()
|
|
||||||
let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
|
let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
|
||||||
const v:{summary_text:string}[] = await classifier(text)
|
const v = await classifier(text) as SummarizationOutput
|
||||||
return v
|
return v[0].summary_text
|
||||||
}
|
}
|
||||||
|
|
||||||
let extractor:Pipeline = null
|
let extractor:transformers.FeatureExtractionPipeline = null
|
||||||
export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
||||||
await loadTransformer()
|
|
||||||
if(!extractor){
|
if(!extractor){
|
||||||
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
|
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
|
||||||
}
|
}
|
||||||
@@ -69,7 +60,11 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
|||||||
let results:Float32Array[] = []
|
let results:Float32Array[] = []
|
||||||
for (let i = 0; i < chunks.length; i++) {
|
for (let i = 0; i < chunks.length; i++) {
|
||||||
let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
|
let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
|
||||||
results.push(result?.data ?? null)
|
const res:Float32Array = result?.data as Float32Array
|
||||||
|
|
||||||
|
if(res){
|
||||||
|
results.push(res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
//set result, as average of all chunks
|
//set result, as average of all chunks
|
||||||
let result:Float32Array = new Float32Array(results[0].length)
|
let result:Float32Array = new Float32Array(results[0].length)
|
||||||
@@ -84,7 +79,7 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
let result = await extractor(text, { pooling: 'mean', normalize: true });
|
let result = await extractor(text, { pooling: 'mean', normalize: true });
|
||||||
return result?.data ?? null;
|
return (result?.data as Float32Array) ?? null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const runTTS = async (text: string) => {
|
export const runTTS = async (text: string) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user