Add initialization function for Transformers and update model paths

This commit is contained in:
kwaroran
2024-01-06 20:49:34 +09:00
parent 85cd1cbc65
commit 46ee706399

View File

@@ -4,9 +4,21 @@ import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
import { selectSingleFile } from 'src/ts/util'; import { selectSingleFile } from 'src/ts/util';
import { v4 } from 'uuid'; import { v4 } from 'uuid';
env.localModelPath = "https://sv.risuai.xyz/transformers/" const tfCache = new Cache()
let tfLoaded = false
async function initTransformers(){
if(tfLoaded){
return
}
const tfCache = new Cache()
env.localModelPath = "/tf/"
env.remoteHost = "https://sv.risuai.xyz/transformers/"
env.customCache = await caches.open('transformers')
tfLoaded = true
}
export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => { export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
await initTransformers()
let text = baseText let text = baseText
let generator = await pipeline('text-generation', model); let generator = await pipeline('text-generation', model);
let output = await generator(text, config) as TextGenerationOutput let output = await generator(text, config) as TextGenerationOutput
@@ -15,6 +27,7 @@ export const runTransformers = async (baseText:string, model:string,config:TextG
} }
export const runSummarizer = async (text: string) => { export const runSummarizer = async (text: string) => {
await initTransformers()
let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6") let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
const v = await classifier(text) as SummarizationOutput const v = await classifier(text) as SummarizationOutput
return v[0].summary_text return v[0].summary_text
@@ -22,6 +35,7 @@ export const runSummarizer = async (text: string) => {
let extractor:FeatureExtractionPipeline = null let extractor:FeatureExtractionPipeline = null
export const runEmbedding = async (text: string):Promise<Float32Array> => { export const runEmbedding = async (text: string):Promise<Float32Array> => {
await initTransformers()
if(!extractor){ if(!extractor){
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'); extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
} }
@@ -73,6 +87,7 @@ export interface OnnxModelFiles {
} }
export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => { export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
await initTransformers()
const {WaveFile} = await import('wavefile') const {WaveFile} = await import('wavefile')
if(modelData === null){ if(modelData === null){
return return
@@ -88,22 +103,7 @@ export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'X
const files = modelData.files const files = modelData.files
const keys = Object.keys(files) const keys = Object.keys(files)
for(const key of keys){ for(const key of keys){
const hasCache:boolean = (await (await fetch("/sw/check/", { tfCache.put(key, new Response(await loadAsset(files[key])))
headers: {
'x-register-url': encodeURIComponent(key)
}
})).json()).able
if(!hasCache){
await fetch("/sw/register/", {
method: "POST",
body: await loadAsset(files[key]),
headers: {
'x-register-url': encodeURIComponent(key),
'x-no-content-type': 'true'
}
})
}
} }
lastSynth = modelData.id lastSynth = modelData.id
synthesizer = await pipeline('text-to-speech', modelData.id); synthesizer = await pipeline('text-to-speech', modelData.id);
@@ -157,7 +157,7 @@ export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
if(url.startsWith('/')){ if(url.startsWith('/')){
url = url.substring(1) url = url.substring(1)
} }
url = '/transformers/' + id +'/' + url url = '/tf/' + id +'/' + url
fileIdMapped[url] = fid fileIdMapped[url] = fid
} }