Refactor Transformers

Author: kwaroran
Date: 2024-01-05 23:51:37 +09:00
parent b6787b93a7
commit 5ae3f35df6


@@ -1,14 +1,7 @@
-import { AutoTokenizer, type Pipeline, type PretrainedOptions } from '@xenova/transformers';
-let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise<Pipeline> = null
+import transformers, { AutoTokenizer, Pipeline, pipeline, type DataArray, type SummarizationOutput } from '@xenova/transformers';
+transformers.env.localModelPath = "https://sv.risuai.xyz/transformers/"
 
-async function loadTransformer() {
-    if(!pipeline){
-        const transformersLib = await import('@xenova/transformers')
-        transformersLib.env.localModelPath = "https://sv.risuai.xyz/transformers/"
-        pipeline = transformersLib.pipeline
-    }
-}
-
 type TransformersBodyType = {
     max_new_tokens: number,
     do_sample: boolean,
@@ -33,23 +26,21 @@ type TransformersBodyType = {
 export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => {
-    await loadTransformer()
     let text = baseText
     let generator = await pipeline('text-generation', model);
-    let output:{generated_text:string}[] = await generator(text);
-    return output
+    let output = await generator(text) as transformers.TextGenerationOutput
+    const outputOne = output[0]
+    return outputOne
 }
 
 export const runSummarizer = async (text: string) => {
-    await loadTransformer()
     let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
-    const v:{summary_text:string}[] = await classifier(text)
-    return v
+    const v = await classifier(text) as SummarizationOutput
+    return v[0].summary_text
 }
 
-let extractor:Pipeline = null
+let extractor:transformers.FeatureExtractionPipeline = null
 
 export const runEmbedding = async (text: string):Promise<Float32Array> => {
-    await loadTransformer()
     if(!extractor){
         extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
     }
@@ -69,7 +60,11 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
     let results:Float32Array[] = []
     for (let i = 0; i < chunks.length; i++) {
         let result = await extractor(chunks[i], { pooling: 'mean', normalize: true });
-        results.push(result?.data ?? null)
+        const res:Float32Array = result?.data as Float32Array
+        if(res){
+            results.push(res)
+        }
     }
     //set result, as average of all chunks
     let result:Float32Array = new Float32Array(results[0].length)
@@ -84,7 +79,7 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
         return result
     }
     let result = await extractor(text, { pooling: 'mean', normalize: true });
-    return result?.data ?? null;
+    return (result?.data as Float32Array) ?? null;
 }
 
 export const runTTS = async (text: string) => {
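
For reference, a minimal sketch of how the refactored helpers might be called after this change. The import path './transformers' and the 'Xenova/gpt2' model name are assumptions (the changed file's path is not shown in this diff), and TransformersBodyType declares more fields than the hunks reveal, so the options object is cast loosely. The return shapes follow the new code: runTransformers resolves to the first TextGenerationOutput element, runSummarizer returns the summary string, and runEmbedding returns a Float32Array.

// Hypothetical call site for the refactored module; './transformers' and the
// model names are assumptions, not part of the commit.
import { runTransformers, runSummarizer, runEmbedding } from './transformers'

async function demo() {
    // Text generation: the new code returns output[0], i.e. a single
    // { generated_text } element rather than the whole array.
    const generated = await runTransformers('Once upon a time,', 'Xenova/gpt2', {
        max_new_tokens: 32,
        do_sample: true,
    } as any) // TransformersBodyType has more fields than the hunks show

    // Summarization: now returns v[0].summary_text directly (a string).
    const summary = await runSummarizer('A long passage of text to condense ...')

    // Embedding: a mean-pooled, normalized vector from Xenova/all-MiniLM-L6-v2;
    // longer inputs are split into chunks and the chunk vectors averaged.
    const embedding = await runEmbedding('Some text to embed')

    console.log(generated.generated_text, summary, embedding.length)
}

Because transformers.env.localModelPath is now set once at module top level, callers no longer need the lazy loadTransformer() step that the old code ran before each call.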