Add model selection for VitsModel

kwaroran
2024-01-06 06:45:18 +09:00
parent 66c6511684
commit 7344e566f4
7 changed files with 366 additions and 220 deletions

View File

@@ -1,6 +1,11 @@
import {env, AutoTokenizer, pipeline, VitsModel, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
import {env, AutoTokenizer, pipeline, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
import { unzip } from 'fflate';
import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
import { selectSingleFile } from 'src/ts/util';
import { v4 } from 'uuid';
env.localModelPath = "https://sv.risuai.xyz/transformers/"
env.localModelPath = "/transformers/"
env.remoteHost = "https://sv.risuai.xyz/transformers/"
export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
let text = baseText
@@ -61,11 +66,49 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
let synthesizer:TextToAudioPipeline = null
let lastSynth:string = null
export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng') => {
export interface OnnxModelFiles {
files: {[key:string]:string},
id: string,
name?: string
}
export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
const {WaveFile} = await import('wavefile')
if((!synthesizer) || (lastSynth !== model)){
lastSynth = model
synthesizer = await pipeline('text-to-speech', model);
if(modelData === null){
return
}
if(typeof modelData === 'string'){
if((!synthesizer) || (lastSynth !== modelData)){
lastSynth = modelData
synthesizer = await pipeline('text-to-speech', modelData);
}
}
else{
if((!synthesizer) || (lastSynth !== modelData.id)){
const files = modelData.files
const keys = Object.keys(files)
for(const key of keys){
const hasCache:boolean = (await (await fetch("/sw/check/", {
headers: {
'x-register-url': encodeURIComponent(key)
}
})).json()).able
if(!hasCache){
await fetch("/sw/register/", {
method: "POST",
body: await loadAsset(files[key]),
headers: {
'x-register-url': encodeURIComponent(key),
'x-no-content-type': 'true'
}
})
}
}
lastSynth = modelData.id
synthesizer = await pipeline('text-to-speech', modelData.id);
}
}
let out = await synthesizer(text, {});
const wav = new WaveFile();
@@ -77,4 +120,52 @@ export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng')
sourceNode.connect(audioContext.destination);
sourceNode.start();
});
}
}
export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
const id = v4().replace(/-/g, '')
const modelFile = await selectSingleFile(['zip'])
if(!modelFile){
return
}
const unziped = await new Promise((res, rej) => {unzip(modelFile.data, {
filter: (file) => {
return file.name.endsWith('.onnx') || file.size < 10_000_000 || file.name.includes('.git')
}
}, (err, unzipped) => {
if(err){
rej(err)
}
else{
res(unzipped)
}
})})
console.log(unziped)
let fileIdMapped:{[key:string]:string} = {}
const keys = Object.keys(unziped)
for(let i = 0; i < keys.length; i++){
const key = keys[i]
const file = unziped[key]
const fid = await saveAsset(file)
let url = key
if(url.startsWith('/')){
url = url.substring(1)
}
url = '/transformers/' + id +'/' + url
fileIdMapped[url] = fid
}
return {
files: fileIdMapped,
name: modelFile.name,
id: id,
}
}
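
Taken together, the two exports above form the new model-selection path: registerOnnxModel unpacks a user-chosen .zip into individually stored assets keyed by virtual /transformers/<id>/... URLs, and runVITS primes the service worker with those files before handing the model id to pipeline(). A minimal sketch of the intended wiring, assuming the module path and the attachVitsModel helper (both hypothetical glue, not part of this commit):

import { registerOnnxModel, runVITS, type OnnxModelFiles } from 'src/ts/process/embedding/transformers';

async function attachVitsModel(character: { vits?: OnnxModelFiles }, text: string) {
    // Opens a file picker for a .zip, saves each entry via saveAsset(),
    // and returns the id plus a /transformers/<id>/... -> asset-id map
    // (or undefined if the user cancels the picker)
    const model = await registerOnnxModel()
    if (!model) return
    character.vits = model
    // On first use, runVITS asks /sw/check/ about each virtual URL and
    // POSTs any missing file to /sw/register/, so the transformers.js
    // pipeline can resolve the model under env.localModelPath ("/transformers/")
    await runVITS(text, character.vits)
}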

View File

@@ -5,211 +5,171 @@ import { runTranslator, translateVox } from "../translator/translator";
import { globalFetch } from "../storage/globalApi";
import { language } from "src/lang";
import { getCurrentCharacter, sleep } from "../util";
import { runVITS } from "./embedding/transformers";
import { registerOnnxModel, runVITS } from "./embedding/transformers";
let sourceNode:AudioBufferSourceNode = null
export async function sayTTS(character:character,text:string) {
if(!character){
const v = getCurrentCharacter()
if(v.type === 'group'){
return
}
character = v
}
let db = get(DataBase)
text = text.replace(/\*/g,'')
if(character.ttsReadOnlyQuoted){
const matches = text.match(/"(.*?)"/g)
if(matches && matches.length > 0){
text = matches.map(match => match.slice(1, -1)).join("");
}
else{
text = ''
}
}
switch(character.ttsMode){
case "webspeech":{
if(speechSynthesis && SpeechSynthesisUtterance){
const utterThis = new SpeechSynthesisUtterance(text);
const voices = speechSynthesis.getVoices();
let voiceIndex = 0
for(let i=0;i<voices.length;i++){
if(voices[i].name === character.ttsSpeech){
voiceIndex = i
}
}
utterThis.voice = voices[voiceIndex]
const speak = speechSynthesis.speak(utterThis)
try {
if(!character){
const v = getCurrentCharacter()
if(v.type === 'group'){
return
}
break
character = v
}
case "elevenlab": {
const audioContext = new AudioContext();
const da = await fetch(`https://api.elevenlabs.io/v1/text-to-speech/${character.ttsSpeech}`, {
body: JSON.stringify({
text: text
}),
method: "POST",
headers: {
"Content-Type": "application/json",
'xi-api-key': db.elevenLabKey || undefined
}
})
if(da.status >= 200 && da.status < 300){
const audioBuffer = await audioContext.decodeAudioData(await da.arrayBuffer())
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
let db = get(DataBase)
text = text.replace(/\*/g,'')
if(character.ttsReadOnlyQuoted){
const matches = text.match(/"(.*?)"/g)
if(matches && matches.length > 0){
text = matches.map(match => match.slice(1, -1)).join("");
}
else{
alertError(await da.text())
text = ''
}
break
}
case "VOICEVOX": {
const jpText = await translateVox(text)
const audioContext = new AudioContext();
const query = await fetch(`${db.voicevoxUrl}/audio_query?text=${jpText}&speaker=${character.ttsSpeech}`, {
method: 'POST',
headers: { "Content-Type": "application/json"},
})
if (query.status == 200){
const queryJson = await query.json();
const bodyData = {
accent_phrases: queryJson.accent_phrases,
speedScale: character.voicevoxConfig.SPEED_SCALE,
pitchScale: character.voicevoxConfig.PITCH_SCALE,
volumeScale: character.voicevoxConfig.VOLUME_SCALE,
intonationScale: character.voicevoxConfig.INTONATION_SCALE,
prePhonemeLength: queryJson.prePhonemeLength,
postPhonemeLength: queryJson.postPhonemeLength,
outputSamplingRate: queryJson.outputSamplingRate,
outputStereo: queryJson.outputStereo,
kana: queryJson.kana,
switch(character.ttsMode){
case "webspeech":{
if(speechSynthesis && SpeechSynthesisUtterance){
const utterThis = new SpeechSynthesisUtterance(text);
const voices = speechSynthesis.getVoices();
let voiceIndex = 0
for(let i=0;i<voices.length;i++){
if(voices[i].name === character.ttsSpeech){
voiceIndex = i
}
}
utterThis.voice = voices[voiceIndex]
const speak = speechSynthesis.speak(utterThis)
}
const getVoice = await fetch(`${db.voicevoxUrl}/synthesis?speaker=${character.ttsSpeech}`, {
break
}
case "elevenlab": {
const audioContext = new AudioContext();
const da = await fetch(`https://api.elevenlabs.io/v1/text-to-speech/${character.ttsSpeech}`, {
body: JSON.stringify({
text: text
}),
method: "POST",
headers: {
"Content-Type": "application/json",
'xi-api-key': db.elevenLabKey || undefined
}
})
if(da.status >= 200 && da.status < 300){
const audioBuffer = await audioContext.decodeAudioData(await da.arrayBuffer())
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
}
else{
alertError(await da.text())
}
break
}
case "VOICEVOX": {
const jpText = await translateVox(text)
const audioContext = new AudioContext();
const query = await fetch(`${db.voicevoxUrl}/audio_query?text=${jpText}&speaker=${character.ttsSpeech}`, {
method: 'POST',
headers: { "Content-Type": "application/json"},
body: JSON.stringify(bodyData),
})
if (getVoice.status == 200 && getVoice.headers.get('content-type') === 'audio/wav'){
const audioBuffer = await audioContext.decodeAudioData(await getVoice.arrayBuffer())
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
}
}
break
}
case 'openai':{
const key = db.openAIKey
const res = await globalFetch('https://api.openai.com/v1/audio/speech', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + key,
},
body: {
model: 'tts-1',
input: text,
voice: character.oaiVoice,
},
rawResponse: true,
})
const dat = res.data
if(res.ok){
try {
const audio = Buffer.from(dat).buffer
const audioContext = new AudioContext();
const audioBuffer = await audioContext.decodeAudioData(audio)
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
} catch (error) {
alertError(language.errors.httpError + `${error}`)
}
}
else{
if(dat.error && dat.error.message){
alertError((language.errors.httpError + `${dat.error.message}`))
}
else{
alertError((language.errors.httpError + `${Buffer.from(res.data).toString()}`))
}
}
break;
}
case 'novelai': {
const audioContext = new AudioContext();
if(text === ''){
break;
}
const encodedText = encodeURIComponent(text);
const encodedSeed = encodeURIComponent(character.naittsConfig.voice);
const url = `https://api.novelai.net/ai/generate-voice?text=${encodedText}&voice=-1&seed=${encodedSeed}&opus=false&version=${character.naittsConfig.version}`;
const response = await globalFetch(url, {
method: 'GET',
headers: {
"Authorization": "Bearer " + db.NAIApiKey,
},
rawResponse: true
});
if (response.ok) {
const audioBuffer = response.data.buffer;
audioContext.decodeAudioData(audioBuffer, (decodedData) => {
const sourceNode = audioContext.createBufferSource();
sourceNode.buffer = decodedData;
sourceNode.connect(audioContext.destination);
sourceNode.start();
});
} else {
alertError("Error fetching or decoding audio data");
}
break;
}
case 'huggingface': {
while(true){
if(character.hfTTS.language !== 'en'){
text = await runTranslator(text, false, 'en', character.hfTTS.language)
}
const audioContext = new AudioContext();
const response = await fetch(`https://api-inference.huggingface.co/models/${character.hfTTS.model}`, {
method: 'POST',
headers: {
"Authorization": "Bearer " + db.huggingfaceKey,
"Content-Type": "application/json",
},
body: JSON.stringify({
inputs: text,
if (query.status == 200){
const queryJson = await query.json();
const bodyData = {
accent_phrases: queryJson.accent_phrases,
speedScale: character.voicevoxConfig.SPEED_SCALE,
pitchScale: character.voicevoxConfig.PITCH_SCALE,
volumeScale: character.voicevoxConfig.VOLUME_SCALE,
intonationScale: character.voicevoxConfig.INTONATION_SCALE,
prePhonemeLength: queryJson.prePhonemeLength,
postPhonemeLength: queryJson.postPhonemeLength,
outputSamplingRate: queryJson.outputSamplingRate,
outputStereo: queryJson.outputStereo,
kana: queryJson.kana,
}
const getVoice = await fetch(`${db.voicevoxUrl}/synthesis?speaker=${character.ttsSpeech}`, {
method: 'POST',
headers: { "Content-Type": "application/json"},
body: JSON.stringify(bodyData),
})
});
if(response.status === 503 && response.headers.get('content-type') === 'application/json'){
const json = await response.json()
if(json.estimated_time){
await sleep(json.estimated_time * 1000)
continue
if (getVoice.status == 200 && getVoice.headers.get('content-type') === 'audio/wav'){
const audioBuffer = await audioContext.decodeAudioData(await getVoice.arrayBuffer())
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
}
}
else if(response.status >= 400){
alertError(language.errors.httpError + `${await response.text()}`)
return
break
}
case 'openai':{
const key = db.openAIKey
const res = await globalFetch('https://api.openai.com/v1/audio/speech', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + key,
},
body: {
model: 'tts-1',
input: text,
voice: character.oaiVoice,
},
rawResponse: true,
})
const dat = res.data
if(res.ok){
try {
const audio = Buffer.from(dat).buffer
const audioContext = new AudioContext();
const audioBuffer = await audioContext.decodeAudioData(audio)
sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
sourceNode.connect(audioContext.destination);
sourceNode.start();
} catch (error) {
alertError(language.errors.httpError + `${error}`)
}
}
else if (response.status === 200) {
const audioBuffer = await response.arrayBuffer();
else{
if(dat.error && dat.error.message){
alertError((language.errors.httpError + `${dat.error.message}`))
}
else{
alertError((language.errors.httpError + `${Buffer.from(res.data).toString()}`))
}
}
break;
}
case 'novelai': {
const audioContext = new AudioContext();
if(text === ''){
break;
}
const encodedText = encodeURIComponent(text);
const encodedSeed = encodeURIComponent(character.naittsConfig.voice);
const url = `https://api.novelai.net/ai/generate-voice?text=${encodedText}&voice=-1&seed=${encodedSeed}&opus=false&version=${character.naittsConfig.version}`;
const response = await globalFetch(url, {
method: 'GET',
headers: {
"Authorization": "Bearer " + db.NAIApiKey,
},
rawResponse: true
});
if (response.ok) {
const audioBuffer = response.data.buffer;
audioContext.decodeAudioData(audioBuffer, (decodedData) => {
const sourceNode = audioContext.createBufferSource();
sourceNode.buffer = decodedData;
@@ -219,12 +179,56 @@ export async function sayTTS(character:character,text:string) {
} else {
alertError("Error fetching or decoding audio data");
}
return
break;
}
}
case 'vits':{
await runVITS(text)
}
case 'huggingface': {
while(true){
if(character.hfTTS.language !== 'en'){
text = await runTranslator(text, false, 'en', character.hfTTS.language)
}
const audioContext = new AudioContext();
const response = await fetch(`https://api-inference.huggingface.co/models/${character.hfTTS.model}`, {
method: 'POST',
headers: {
"Authorization": "Bearer " + db.huggingfaceKey,
"Content-Type": "application/json",
},
body: JSON.stringify({
inputs: text,
})
});
if(response.status === 503 && response.headers.get('content-type') === 'application/json'){
const json = await response.json()
if(json.estimated_time){
await sleep(json.estimated_time * 1000)
continue
}
}
else if(response.status >= 400){
alertError(language.errors.httpError + `${await response.text()}`)
return
}
else if (response.status === 200) {
const audioBuffer = await response.arrayBuffer();
audioContext.decodeAudioData(audioBuffer, (decodedData) => {
const sourceNode = audioContext.createBufferSource();
sourceNode.buffer = decodedData;
sourceNode.connect(audioContext.destination);
sourceNode.start();
});
} else {
alertError("Error fetching or decoding audio data");
}
return
}
}
case 'vits':{
await runVITS(text, character.vits)
}
}
} catch (error) {
alertError(`TTS Error: ${error}`)
}
}
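
Downstream, the reworked sayTTS is exercised the same way as before; an illustrative call, assuming a character persisted with ttsMode 'vits' and a model already attached through registerOnnxModel():

// `character` comes from the database; only the routing is from this diff
await sayTTS(character, 'Hello there.')
// The switch dispatches to runVITS(text, character.vits), and any failure
// in a TTS branch now surfaces as alertError(`TTS Error: ...`) through the
// try/catch added in this commit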

View File

@@ -695,7 +695,8 @@ export interface character{
hfTTS?: {
model: string
language: string
}
},
vits?: OnnxModelFiles
}
@@ -1115,6 +1116,7 @@ export function setPreset(db:Database, newPres: botPreset){
import { encode as encodeMsgpack, decode as decodeMsgpack } from "msgpackr";
import * as fflate from "fflate";
import type { OnnxModelFiles } from '../process/embedding/transformers';
export async function downloadPreset(id:number){
saveCurrentPreset()
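
For reference, the persisted shape of the new optional field, with made-up values (the saveAsset id format is not shown in this diff):

// Illustrative character.vits value — ids and filenames are invented
const vits: OnnxModelFiles = {
    id: 'c0ffee00deadbeef',         // v4() uuid with dashes stripped (shortened here)
    name: 'my-voice-model.zip',     // original archive filename
    files: {
        // virtual URL served by the service worker -> saveAsset() return value
        '/transformers/c0ffee00deadbeef/model.onnx': '<id returned by saveAsset>',
        '/transformers/c0ffee00deadbeef/tokenizer.json': '<another saveAsset id>',
    },
}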

View File

@@ -11,7 +11,7 @@ import { checkOldDomain, checkUpdate } from "../update";
import { botMakerMode, selectedCharID } from "../stores";
import { Body, ResponseType, fetch as TauriFetch } from "@tauri-apps/api/http";
import { loadPlugins } from "../plugins/plugins";
import { alertConfirm, alertError } from "../alert";
import { alertConfirm, alertError, alertNormal } from "../alert";
import { checkDriverInit, syncDrive } from "../drive/drive";
import { hasher } from "../parser";
import { characterURLImport, hubURL } from "../characterCards";
@@ -231,6 +231,15 @@ export async function saveAsset(data:Uint8Array, customId:string = '', fileName:
}
}
export async function loadAsset(id:string){
if(isTauri){
return await readBinaryFile(id,{dir: BaseDirectory.AppData})
}
else{
return await forageStorage.getItem(id) as Uint8Array
}
}
let lastSave = ''
export async function saveDb(){
@@ -369,6 +378,7 @@ export async function loadData() {
throw "Your save file is corrupted"
}
}
await registerSw()
await checkUpdate()
await changeFullscreen()
@@ -432,15 +442,7 @@ export async function loadData() {
}
if(navigator.serviceWorker && (!Capacitor.isNativePlatform())){
usingSw = true
await navigator.serviceWorker.register("/sw.js", {
scope: "/"
});
await sleep(100)
const da = await fetch('/sw/init')
if(!(da.status >= 200 && da.status < 300)){
location.reload()
}
await registerSw()
}
else{
usingSw = false
@@ -792,6 +794,20 @@ export async function globalFetch(url:string, arg:{
}
}
async function registerSw() {
await navigator.serviceWorker.register("/sw.js", {
scope: "/"
});
await sleep(100)
const da = await fetch('/sw/init')
if(!(da.status >= 200 && da.status < 300)){
location.reload()
}
else{
}
}
const re = /\\/g
function getBasename(data:string){
const splited = data.replace(re, '/').split('/')
@@ -833,6 +849,13 @@ export function getUnpargeables(db:Database, uptype:'basename'|'pure' = 'basename'
addUnparge(em[1])
}
}
if(cha.vits){
const keys = Object.keys(cha.vits.files)
for(const key of keys){
const vit = cha.vits.files[key]
addUnparge(vit)
}
}
}
}
@@ -1044,7 +1067,7 @@ async function pargeChunks(){
const assets = await readDir('assets', {dir: BaseDirectory.AppData})
for(const asset of assets){
const n = getBasename(asset.name)
if(unpargeable.includes(n) || (!n.endsWith('png'))){
if(unpargeable.includes(n)){
}
else{
await removeFile(asset.path)
@@ -1054,8 +1077,11 @@ async function pargeChunks(){
else{
const indexes = await forageStorage.keys()
for(const asset of indexes){
if(!asset.startsWith('assets/')){
continue
}
const n = getBasename(asset)
if(unpargeable.includes(n) || (!asset.endsWith(".png"))){
if(unpargeable.includes(n)){
}
else{
await forageStorage.removeItem(asset)
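
Two things make the purge change above safe: getUnpargeables now walks cha.vits.files so registered model assets survive, and the old blanket exemption for non-.png assets is dropped, so only listed assets are kept. Everything else rests on the service worker, whose /sw/check/ and /sw/register/ endpoints are not part of this diff; the sketch below is a guess at the contract the client code implies (the FetchEvent type assumes TypeScript's webworker lib; cache name and fallback behavior are invented):

// Hypothetical sw.js contract — the real service worker is NOT in this diff
const MODEL_CACHE = 'transformers-models'

self.addEventListener('fetch', (event: FetchEvent) => {
    const { pathname } = new URL(event.request.url)
    const target = decodeURIComponent(event.request.headers.get('x-register-url') ?? '')

    if (pathname === '/sw/check/') {
        // mirror of the client's hasCache probe: respond with {able: boolean}
        event.respondWith((async () => {
            const hit = await (await caches.open(MODEL_CACHE)).match(target)
            return new Response(JSON.stringify({ able: !!hit }))
        })())
    } else if (pathname === '/sw/register/') {
        // store the POSTed model file body under its virtual URL
        event.respondWith((async () => {
            const body = await event.request.arrayBuffer()
            await (await caches.open(MODEL_CACHE)).put(target, new Response(body))
            return new Response('ok')
        })())
    } else if (pathname.startsWith('/transformers/')) {
        // serve model files from the cache, falling back to the network
        // (env.remoteHost points remote fetches at https://sv.risuai.xyz/transformers/)
        event.respondWith((async () => {
            const hit = await (await caches.open(MODEL_CACHE)).match(pathname)
            return hit ?? fetch(event.request)
        })())
    }
})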