[feat] added transformers

This commit is contained in:
kwaroran
2023-08-05 23:08:39 +09:00
parent 85c240a80b
commit aae5085f9f
4 changed files with 465 additions and 9 deletions

View File

@@ -0,0 +1,57 @@
import type { Pipeline, PretrainedOptions } from '@xenova/transformers';
import { DataBase } from 'src/ts/storage/database';
import { get } from 'svelte/store';
let pipeline: (task: string, model?: string, { quantized, progress_callback, config, cache_dir, local_files_only, revision, }?: PretrainedOptions) => Promise<Pipeline> = null
async function loadTransformer() {
if(!pipeline){
const transformersLib = await import('@xenova/transformers')
pipeline = transformersLib.pipeline
}
}
type TransformersBodyType = {
max_new_tokens: number,
do_sample: boolean,
temperature: number,
top_p: number,
typical_p: number,
repetition_penalty: number,
encoder_repetition_penalty: number,
top_k: number,
min_length: number,
no_repeat_ngram_size: number,
num_beams: number,
penalty_alpha: number,
length_penalty: number,
early_stopping: boolean,
truncation_length: number,
ban_eos_token: boolean,
stopping_strings: number,
seed: number,
add_bos_token: boolean,
}
export const runTransformers = async (baseText:string, model:string,bodyTemplate:TransformersBodyType) => {
await loadTransformer()
let text = baseText
let generator = await pipeline('text-generation', model);
let output = await generator(text, bodyTemplate);
return output
}
export const runSummarizer = async (text: string) => {
await loadTransformer()
let classifier = await pipeline("summarization", "Xenova/bart-large-cnn")
const v:{summary_text:string}[] = await classifier(text)
return v
}
export const runEmbedding = async (text: string):Promise<Float32Array> => {
await loadTransformer()
let extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
let result = await extractor(text, { pooling: 'mean', normalize: true });
return result?.data ?? null;
}

View File

@@ -140,10 +140,6 @@ export function processScriptFull(char:character|groupChat, data:string, mode:Sc
}
}
else{
if(randomness.test(data)){
const list = data.split('|||')
data = list[Math.floor(Math.random()*list.length)];
}
data = data.replace(reg, outScript)
}
}