Add initialization function for Transformers and update model paths
This commit is contained in:
@@ -4,9 +4,21 @@ import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
|
||||
import { selectSingleFile } from 'src/ts/util';
|
||||
import { v4 } from 'uuid';
|
||||
|
||||
env.localModelPath = "https://sv.risuai.xyz/transformers/"
|
||||
const tfCache = new Cache()
|
||||
let tfLoaded = false
|
||||
async function initTransformers(){
|
||||
if(tfLoaded){
|
||||
return
|
||||
}
|
||||
const tfCache = new Cache()
|
||||
env.localModelPath = "/tf/"
|
||||
env.remoteHost = "https://sv.risuai.xyz/transformers/"
|
||||
env.customCache = await caches.open('transformers')
|
||||
tfLoaded = true
|
||||
}
|
||||
|
||||
export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
|
||||
await initTransformers()
|
||||
let text = baseText
|
||||
let generator = await pipeline('text-generation', model);
|
||||
let output = await generator(text, config) as TextGenerationOutput
|
||||
@@ -15,6 +27,7 @@ export const runTransformers = async (baseText:string, model:string,config:TextG
|
||||
}
|
||||
|
||||
export const runSummarizer = async (text: string) => {
|
||||
await initTransformers()
|
||||
let classifier = await pipeline("summarization", "Xenova/distilbart-cnn-6-6")
|
||||
const v = await classifier(text) as SummarizationOutput
|
||||
return v[0].summary_text
|
||||
@@ -22,6 +35,7 @@ export const runSummarizer = async (text: string) => {
|
||||
|
||||
let extractor:FeatureExtractionPipeline = null
|
||||
export const runEmbedding = async (text: string):Promise<Float32Array> => {
|
||||
await initTransformers()
|
||||
if(!extractor){
|
||||
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
|
||||
}
|
||||
@@ -73,6 +87,7 @@ export interface OnnxModelFiles {
|
||||
}
|
||||
|
||||
export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
|
||||
await initTransformers()
|
||||
const {WaveFile} = await import('wavefile')
|
||||
if(modelData === null){
|
||||
return
|
||||
@@ -88,22 +103,7 @@ export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'X
|
||||
const files = modelData.files
|
||||
const keys = Object.keys(files)
|
||||
for(const key of keys){
|
||||
const hasCache:boolean = (await (await fetch("/sw/check/", {
|
||||
headers: {
|
||||
'x-register-url': encodeURIComponent(key)
|
||||
}
|
||||
})).json()).able
|
||||
|
||||
if(!hasCache){
|
||||
await fetch("/sw/register/", {
|
||||
method: "POST",
|
||||
body: await loadAsset(files[key]),
|
||||
headers: {
|
||||
'x-register-url': encodeURIComponent(key),
|
||||
'x-no-content-type': 'true'
|
||||
}
|
||||
})
|
||||
}
|
||||
tfCache.put(key, new Response(await loadAsset(files[key])))
|
||||
}
|
||||
lastSynth = modelData.id
|
||||
synthesizer = await pipeline('text-to-speech', modelData.id);
|
||||
@@ -157,7 +157,7 @@ export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
|
||||
if(url.startsWith('/')){
|
||||
url = url.substring(1)
|
||||
}
|
||||
url = '/transformers/' + id +'/' + url
|
||||
url = '/tf/' + id +'/' + url
|
||||
fileIdMapped[url] = fid
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user