Add model selection for VitsModel

This commit is contained in:
kwaroran
2024-01-06 06:45:18 +09:00
parent 66c6511684
commit 7344e566f4
7 changed files with 366 additions and 220 deletions

View File

@@ -1,6 +1,11 @@
import {env, AutoTokenizer, pipeline, VitsModel, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
import {env, AutoTokenizer, pipeline, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
import { unzip } from 'fflate';
import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
import { selectSingleFile } from 'src/ts/util';
import { v4 } from 'uuid';
env.localModelPath = "https://sv.risuai.xyz/transformers/"
env.localModelPath = "/transformers/"
env.remoteHost = "https://sv.risuai.xyz/transformers/"
export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
let text = baseText
@@ -61,11 +66,49 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
let synthesizer:TextToAudioPipeline = null
let lastSynth:string = null
export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng') => {
export interface OnnxModelFiles {
files: {[key:string]:string},
id: string,
name?: string
}
export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
const {WaveFile} = await import('wavefile')
if((!synthesizer) || (lastSynth !== model)){
lastSynth = model
synthesizer = await pipeline('text-to-speech', model);
if(modelData === null){
return
}
if(typeof modelData === 'string'){
if((!synthesizer) || (lastSynth !== modelData)){
lastSynth = modelData
synthesizer = await pipeline('text-to-speech', modelData);
}
}
else{
if((!synthesizer) || (lastSynth !== modelData.id)){
const files = modelData.files
const keys = Object.keys(files)
for(const key of keys){
const hasCache:boolean = (await (await fetch("/sw/check/", {
headers: {
'x-register-url': encodeURIComponent(key)
}
})).json()).able
if(!hasCache){
await fetch("/sw/register/", {
method: "POST",
body: await loadAsset(files[key]),
headers: {
'x-register-url': encodeURIComponent(key),
'x-no-content-type': 'true'
}
})
}
}
lastSynth = modelData.id
synthesizer = await pipeline('text-to-speech', modelData.id);
}
}
let out = await synthesizer(text, {});
const wav = new WaveFile();
@@ -77,4 +120,52 @@ export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng')
sourceNode.connect(audioContext.destination);
sourceNode.start();
});
}
}
export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
const id = v4().replace(/-/g, '')
const modelFile = await selectSingleFile(['zip'])
if(!modelFile){
return
}
const unziped = await new Promise((res, rej) => {unzip(modelFile.data, {
filter: (file) => {
return file.name.endsWith('.onnx') || file.size < 10_000_000 || file.name.includes('.git')
}
}, (err, unzipped) => {
if(err){
rej(err)
}
else{
res(unzipped)
}
})})
console.log(unziped)
let fileIdMapped:{[key:string]:string} = {}
const keys = Object.keys(unziped)
for(let i = 0; i < keys.length; i++){
const key = keys[i]
const file = unziped[key]
const fid = await saveAsset(file)
let url = key
if(url.startsWith('/')){
url = url.substring(1)
}
url = '/transformers/' + id +'/' + url
fileIdMapped[url] = fid
}
return {
files: fileIdMapped,
name: modelFile.name,
id: id,
}
}