From 7344e566f4f653be83e1eb04ec924f637d92e9b8 Mon Sep 17 00:00:00 2001
From: kwaroran
Date: Sat, 6 Jan 2024 06:45:18 +0900
Subject: [PATCH] Add model selection for VitsModel

---
 public/sw.js                              |  25 +-
 src/lang/en.ts                            |   1 +
 src/lib/SideBars/CharConfig.svelte        |  14 +
 src/ts/process/embedding/transformers.ts  | 104 +++++-
 src/ts/process/tts.ts                     | 388 ++++++++++++-----------
 src/ts/storage/database.ts                |   4 +-
 src/ts/storage/globalApi.ts               |  50 ++-
 7 files changed, 366 insertions(+), 220 deletions(-)

diff --git a/public/sw.js b/public/sw.js
index 87117a85..938a761b 100644
--- a/public/sw.js
+++ b/public/sw.js
@@ -7,7 +7,13 @@ self.addEventListener('fetch', (event) => {
     try {
         switch (path[2]){
             case "check":{
-                event.respondWith(checkCache(url))
+                let targetUrl = url
+                const headers = event.request.headers
+                const headerUrl = headers.get('x-register-url')
+                if(headerUrl){
+                    targetUrl.pathname = decodeURIComponent(headerUrl)
+                }
+                event.respondWith(checkCache(targetUrl))
                 break
             }
             case "img": {
@@ -15,20 +21,21 @@
                 break
             }
             case "register": {
-                let targerUrl = url
+                let targetUrl = url
                 const headers = event.request.headers
                 const headerUrl = headers.get('x-register-url')
                 if(headerUrl){
-                    targerUrl = new URL(headerUrl)
+                    targetUrl.pathname = decodeURIComponent(headerUrl)
                 }
                 const noContentType = headers.get('x-no-content-type') === 'true'
                 event.respondWith(
-                    registerCache(targerUrl, event.request.arrayBuffer(), noContentType)
+                    registerCache(targetUrl, event.request.arrayBuffer(), noContentType)
                 )
                 break
             }
             case "init":{
-                event.respondWith(new Response("true"))
+                event.respondWith(new Response("v2"))
+                break
             }
             default: {
                 event.respondWith(new Response(
@@ -74,9 +81,11 @@ async function check(){
 async function registerCache(urlr, buffer, noContentType = false){
     const cache = await caches.open('risuCache')
     const url = new URL(urlr)
-    let path = url.pathname.split('/')
-    path[2] = 'img'
-    url.pathname = path.join('/')
+    if(!noContentType){
+        let path = url.pathname.split('/')
+        path[2] = 'img'
+        url.pathname = path.join('/')
+    }
     const buf = new Uint8Array(await buffer)
     let headers = {
         "cache-control": "max-age=604800",
diff --git a/src/lang/en.ts b/src/lang/en.ts
index d993eed3..fc3ce605 100644
--- a/src/lang/en.ts
+++ b/src/lang/en.ts
@@ -484,4 +484,5 @@
     chatAsOriginalOnSystem: "Send as original role",
     exportAsDataset: "Export Save as Dataset",
     editTranslationDisplay: "Edit Translation Display",
+    selectModel: "Select Model",
 }
\ No newline at end of file
diff --git a/src/lib/SideBars/CharConfig.svelte b/src/lib/SideBars/CharConfig.svelte
index dfaa8240..2a6c44ff 100644
--- a/src/lib/SideBars/CharConfig.svelte
+++ b/src/lib/SideBars/CharConfig.svelte
@@ -29,6 +29,7 @@
     import TriggerList from "./Scripts/TriggerList.svelte";
     import CheckInput from "../UI/GUI/CheckInput.svelte";
     import { updateInlayScreen } from "src/ts/process/inlayScreen";
+    import { registerOnnxModel } from "src/ts/process/embedding/transformers";
 
     let subMenu = 0
 
@@ -626,6 +627,19 @@
                 Language
             {/if}
+            {#if currentChar.data.ttsMode === 'vits'}
+                <span>
+                    {#if currentChar.data.vits}
+                        {currentChar.data.vits.name ?? 'Unnamed VitsModel'}
+                    {:else}
+                        No Model
+                    {/if}
+                </span>
+                <button on:click={async () => {
+                    const model = await registerOnnxModel()
+                    if(model){ currentChar.data.vits = model }
+                }}>{language.selectModel}</button>
+            {/if}
             {#if currentChar.data.ttsMode}
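A note on the worker protocol above: both /sw/check/ and /sw/register/ now address cache entries through the x-register-url header rather than through the request URL itself, so callers can cache arbitrary virtual paths (the ONNX model files registered later in this patch). Below is a minimal client-side sketch of that contract; ensureCached is a hypothetical helper name, and the {able: boolean} response shape is assumed from how runVITS reads it in transformers.ts.

```ts
// Hypothetical helper showing the /sw/check + /sw/register contract.
// Assumes checkCache answers with JSON {able: boolean}, which is how
// runVITS in transformers.ts reads the response.
async function ensureCached(path: string, bytes: Uint8Array): Promise<void> {
    // Ask the worker whether `path` already lives in risuCache.
    const check = await fetch("/sw/check/", {
        headers: { "x-register-url": encodeURIComponent(path) },
    });
    const { able } = (await check.json()) as { able: boolean };
    if (able) {
        return;
    }
    // Upload the bytes; x-no-content-type keeps registerCache from
    // rewriting the path segment to "img" before storing it.
    await fetch("/sw/register/", {
        method: "POST",
        body: bytes,
        headers: {
            "x-register-url": encodeURIComponent(path),
            "x-no-content-type": "true",
        },
    });
}
```
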
diff --git a/src/ts/process/embedding/transformers.ts b/src/ts/process/embedding/transformers.ts
index 6cd2e49e..400e6c2b 100644
--- a/src/ts/process/embedding/transformers.ts
+++ b/src/ts/process/embedding/transformers.ts
@@ -1,6 +1,11 @@
-import {env, AutoTokenizer, pipeline, VitsModel, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
+import {env, AutoTokenizer, pipeline, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
+import { unzip } from 'fflate';
+import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
+import { selectSingleFile } from 'src/ts/util';
+import { v4 } from 'uuid';
 
-env.localModelPath = "https://sv.risuai.xyz/transformers/"
+env.localModelPath = "/transformers/"
+env.remoteHost = "https://sv.risuai.xyz/transformers/"
 
 export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
     let text = baseText
@@ -61,11 +66,49 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
 
 let synthesizer:TextToAudioPipeline = null
 let lastSynth:string = null
-export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng') => {
+
+export interface OnnxModelFiles {
+    files: {[key:string]:string},
+    id: string,
+    name?: string
+}
+
+export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
     const {WaveFile} = await import('wavefile')
-    if((!synthesizer) || (lastSynth !== model)){
-        lastSynth = model
-        synthesizer = await pipeline('text-to-speech', model);
+    if(modelData === null){
+        return
+    }
+    if(typeof modelData === 'string'){
+        if((!synthesizer) || (lastSynth !== modelData)){
+            lastSynth = modelData
+            synthesizer = await pipeline('text-to-speech', modelData);
+        }
+    }
+    else{
+        if((!synthesizer) || (lastSynth !== modelData.id)){
+            const files = modelData.files
+            const keys = Object.keys(files)
+            for(const key of keys){
+                const hasCache:boolean = (await (await fetch("/sw/check/", {
+                    headers: {
+                        'x-register-url': encodeURIComponent(key)
+                    }
+                })).json()).able
+
+                if(!hasCache){
+                    await fetch("/sw/register/", {
+                        method: "POST",
+                        body: await loadAsset(files[key]),
+                        headers: {
+                            'x-register-url': encodeURIComponent(key),
+                            'x-no-content-type': 'true'
+                        }
+                    })
+                }
+            }
+            lastSynth = modelData.id
+            synthesizer = await pipeline('text-to-speech', modelData.id);
+        }
     }
     let out = await synthesizer(text, {});
     const wav = new WaveFile();
@@ -77,4 +120,51 @@
         sourceNode.connect(audioContext.destination);
         sourceNode.start();
     });
-}
\ No newline at end of file
+}
+
+
+export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
+    const id = v4().replace(/-/g, '')
+
+    const modelFile = await selectSingleFile(['zip'])
+
+    if(!modelFile){
+        return
+    }
+
+    // keep .onnx weights, any entry under 10MB, and .git* metadata files
+    const unzipped = await new Promise<{[key:string]:Uint8Array}>((res, rej) => {unzip(modelFile.data, {
+        filter: (file) => {
+            return file.name.endsWith('.onnx') || file.size < 10_000_000 || file.name.includes('.git')
+        }
+    }, (err, unzipped) => {
+        if(err){
+            rej(err)
+        }
+        else{
+            res(unzipped)
+        }
+    })})
+
+    let fileIdMapped:{[key:string]:string} = {}
+
+    const keys = Object.keys(unzipped)
+    for(let i = 0; i < keys.length; i++){
+        const key = keys[i]
+        const file = unzipped[key]
+        const fid = await saveAsset(file)
+        let url = key
if(url.startsWith('/')){ + url = url.substring(1) + } + url = '/transformers/' + id +'/' + url + fileIdMapped[url] = fid + } + + return { + files: fileIdMapped, + name: modelFile.name, + id: id, + } + +} diff --git a/src/ts/process/tts.ts b/src/ts/process/tts.ts index 8c29f28a..d117636a 100644 --- a/src/ts/process/tts.ts +++ b/src/ts/process/tts.ts @@ -5,211 +5,171 @@ import { runTranslator, translateVox } from "../translator/translator"; import { globalFetch } from "../storage/globalApi"; import { language } from "src/lang"; import { getCurrentCharacter, sleep } from "../util"; -import { runVITS } from "./embedding/transformers"; +import { registerOnnxModel, runVITS } from "./embedding/transformers"; let sourceNode:AudioBufferSourceNode = null export async function sayTTS(character:character,text:string) { - if(!character){ - const v = getCurrentCharacter() - if(v.type === 'group'){ - return - } - character = v - } - - let db = get(DataBase) - text = text.replace(/\*/g,'') - - if(character.ttsReadOnlyQuoted){ - const matches = text.match(/"(.*?)"/g) - if(matches && matches.length > 0){ - text = matches.map(match => match.slice(1, -1)).join(""); - } - else{ - text = '' - } - } - - switch(character.ttsMode){ - case "webspeech":{ - if(speechSynthesis && SpeechSynthesisUtterance){ - const utterThis = new SpeechSynthesisUtterance(text); - const voices = speechSynthesis.getVoices(); - let voiceIndex = 0 - for(let i=0;i= 200 && da.status < 300){ - const audioBuffer = await audioContext.decodeAudioData(await da.arrayBuffer()) - sourceNode = audioContext.createBufferSource(); - sourceNode.buffer = audioBuffer; - sourceNode.connect(audioContext.destination); - sourceNode.start(); + + let db = get(DataBase) + text = text.replace(/\*/g,'') + + if(character.ttsReadOnlyQuoted){ + const matches = text.match(/"(.*?)"/g) + if(matches && matches.length > 0){ + text = matches.map(match => match.slice(1, -1)).join(""); } else{ - alertError(await da.text()) + text = '' } - break } - case "VOICEVOX": { - const jpText = await translateVox(text) - const audioContext = new AudioContext(); - const query = await fetch(`${db.voicevoxUrl}/audio_query?text=${jpText}&speaker=${character.ttsSpeech}`, { - method: 'POST', - headers: { "Content-Type": "application/json"}, - }) - if (query.status == 200){ - const queryJson = await query.json(); - const bodyData = { - accent_phrases: queryJson.accent_phrases, - speedScale: character.voicevoxConfig.SPEED_SCALE, - pitchScale: character.voicevoxConfig.PITCH_SCALE, - volumeScale: character.voicevoxConfig.VOLUME_SCALE, - intonationScale: character.voicevoxConfig.INTONATION_SCALE, - prePhonemeLength: queryJson.prePhonemeLength, - postPhonemeLength: queryJson.postPhonemeLength, - outputSamplingRate: queryJson.outputSamplingRate, - outputStereo: queryJson.outputStereo, - kana: queryJson.kana, + + switch(character.ttsMode){ + case "webspeech":{ + if(speechSynthesis && SpeechSynthesisUtterance){ + const utterThis = new SpeechSynthesisUtterance(text); + const voices = speechSynthesis.getVoices(); + let voiceIndex = 0 + for(let i=0;i= 200 && da.status < 300){ + const audioBuffer = await audioContext.decodeAudioData(await da.arrayBuffer()) + sourceNode = audioContext.createBufferSource(); + sourceNode.buffer = audioBuffer; + sourceNode.connect(audioContext.destination); + sourceNode.start(); + } + else{ + alertError(await da.text()) + } + break + } + case "VOICEVOX": { + const jpText = await translateVox(text) + const audioContext = new AudioContext(); + const query = await 
fetch(`${db.voicevoxUrl}/audio_query?text=${jpText}&speaker=${character.ttsSpeech}`, { method: 'POST', headers: { "Content-Type": "application/json"}, - body: JSON.stringify(bodyData), }) - if (getVoice.status == 200 && getVoice.headers.get('content-type') === 'audio/wav'){ - const audioBuffer = await audioContext.decodeAudioData(await getVoice.arrayBuffer()) - sourceNode = audioContext.createBufferSource(); - sourceNode.buffer = audioBuffer; - sourceNode.connect(audioContext.destination); - sourceNode.start(); - } - } - break - } - case 'openai':{ - const key = db.openAIKey - const res = await globalFetch('https://api.openai.com/v1/audio/speech', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': 'Bearer ' + key, - }, - body: { - model: 'tts-1', - input: text, - voice: character.oaiVoice, - - }, - rawResponse: true, - }) - const dat = res.data - - if(res.ok){ - try { - const audio = Buffer.from(dat).buffer - const audioContext = new AudioContext(); - const audioBuffer = await audioContext.decodeAudioData(audio) - sourceNode = audioContext.createBufferSource(); - sourceNode.buffer = audioBuffer; - sourceNode.connect(audioContext.destination); - sourceNode.start(); - } catch (error) { - alertError(language.errors.httpError + `${error}`) - } - } - else{ - if(dat.error && dat.error.message){ - alertError((language.errors.httpError + `${dat.error.message}`)) - } - else{ - alertError((language.errors.httpError + `${Buffer.from(res.data).toString()}`)) - } - } - break; - - } - case 'novelai': { - const audioContext = new AudioContext(); - if(text === ''){ - break; - } - const encodedText = encodeURIComponent(text); - const encodedSeed = encodeURIComponent(character.naittsConfig.voice); - - const url = `https://api.novelai.net/ai/generate-voice?text=${encodedText}&voice=-1&seed=${encodedSeed}&opus=false&version=${character.naittsConfig.version}`; - - const response = await globalFetch(url, { - method: 'GET', - headers: { - "Authorization": "Bearer " + db.NAIApiKey, - }, - rawResponse: true - }); - - if (response.ok) { - const audioBuffer = response.data.buffer; - audioContext.decodeAudioData(audioBuffer, (decodedData) => { - const sourceNode = audioContext.createBufferSource(); - sourceNode.buffer = decodedData; - sourceNode.connect(audioContext.destination); - sourceNode.start(); - }); - } else { - alertError("Error fetching or decoding audio data"); - } - break; - } - case 'huggingface': { - while(true){ - if(character.hfTTS.language !== 'en'){ - text = await runTranslator(text, false, 'en', character.hfTTS.language) - } - const audioContext = new AudioContext(); - const response = await fetch(`https://api-inference.huggingface.co/models/${character.hfTTS.model}`, { - method: 'POST', - headers: { - "Authorization": "Bearer " + db.huggingfaceKey, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - inputs: text, + if (query.status == 200){ + const queryJson = await query.json(); + const bodyData = { + accent_phrases: queryJson.accent_phrases, + speedScale: character.voicevoxConfig.SPEED_SCALE, + pitchScale: character.voicevoxConfig.PITCH_SCALE, + volumeScale: character.voicevoxConfig.VOLUME_SCALE, + intonationScale: character.voicevoxConfig.INTONATION_SCALE, + prePhonemeLength: queryJson.prePhonemeLength, + postPhonemeLength: queryJson.postPhonemeLength, + outputSamplingRate: queryJson.outputSamplingRate, + outputStereo: queryJson.outputStereo, + kana: queryJson.kana, + } + const getVoice = await 
fetch(`${db.voicevoxUrl}/synthesis?speaker=${character.ttsSpeech}`, { + method: 'POST', + headers: { "Content-Type": "application/json"}, + body: JSON.stringify(bodyData), }) - }); - - if(response.status === 503 && response.headers.get('content-type') === 'application/json'){ - const json = await response.json() - if(json.estimated_time){ - await sleep(json.estimated_time * 1000) - continue + if (getVoice.status == 200 && getVoice.headers.get('content-type') === 'audio/wav'){ + const audioBuffer = await audioContext.decodeAudioData(await getVoice.arrayBuffer()) + sourceNode = audioContext.createBufferSource(); + sourceNode.buffer = audioBuffer; + sourceNode.connect(audioContext.destination); + sourceNode.start(); } } - else if(response.status >= 400){ - alertError(language.errors.httpError + `${await response.text()}`) - return + break + } + case 'openai':{ + const key = db.openAIKey + const res = await globalFetch('https://api.openai.com/v1/audio/speech', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + key, + }, + body: { + model: 'tts-1', + input: text, + voice: character.oaiVoice, + + }, + rawResponse: true, + }) + const dat = res.data + + if(res.ok){ + try { + const audio = Buffer.from(dat).buffer + const audioContext = new AudioContext(); + const audioBuffer = await audioContext.decodeAudioData(audio) + sourceNode = audioContext.createBufferSource(); + sourceNode.buffer = audioBuffer; + sourceNode.connect(audioContext.destination); + sourceNode.start(); + } catch (error) { + alertError(language.errors.httpError + `${error}`) + } } - else if (response.status === 200) { - const audioBuffer = await response.arrayBuffer(); + else{ + if(dat.error && dat.error.message){ + alertError((language.errors.httpError + `${dat.error.message}`)) + } + else{ + alertError((language.errors.httpError + `${Buffer.from(res.data).toString()}`)) + } + } + break; + + } + case 'novelai': { + const audioContext = new AudioContext(); + if(text === ''){ + break; + } + const encodedText = encodeURIComponent(text); + const encodedSeed = encodeURIComponent(character.naittsConfig.voice); + + const url = `https://api.novelai.net/ai/generate-voice?text=${encodedText}&voice=-1&seed=${encodedSeed}&opus=false&version=${character.naittsConfig.version}`; + + const response = await globalFetch(url, { + method: 'GET', + headers: { + "Authorization": "Bearer " + db.NAIApiKey, + }, + rawResponse: true + }); + + if (response.ok) { + const audioBuffer = response.data.buffer; audioContext.decodeAudioData(audioBuffer, (decodedData) => { const sourceNode = audioContext.createBufferSource(); sourceNode.buffer = decodedData; @@ -219,12 +179,56 @@ export async function sayTTS(character:character,text:string) { } else { alertError("Error fetching or decoding audio data"); } - return + break; } - } - case 'vits':{ - await runVITS(text) - } + case 'huggingface': { + while(true){ + if(character.hfTTS.language !== 'en'){ + text = await runTranslator(text, false, 'en', character.hfTTS.language) + } + const audioContext = new AudioContext(); + const response = await fetch(`https://api-inference.huggingface.co/models/${character.hfTTS.model}`, { + method: 'POST', + headers: { + "Authorization": "Bearer " + db.huggingfaceKey, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + inputs: text, + }) + }); + + if(response.status === 503 && response.headers.get('content-type') === 'application/json'){ + const json = await response.json() + if(json.estimated_time){ + await 
sleep(json.estimated_time * 1000) + continue + } + } + else if(response.status >= 400){ + alertError(language.errors.httpError + `${await response.text()}`) + return + } + else if (response.status === 200) { + const audioBuffer = await response.arrayBuffer(); + audioContext.decodeAudioData(audioBuffer, (decodedData) => { + const sourceNode = audioContext.createBufferSource(); + sourceNode.buffer = decodedData; + sourceNode.connect(audioContext.destination); + sourceNode.start(); + }); + } else { + alertError("Error fetching or decoding audio data"); + } + return + } + } + case 'vits':{ + await runVITS(text, character.vits) + } + } + } catch (error) { + alertError(`TTS Error: ${error}`) } } diff --git a/src/ts/storage/database.ts b/src/ts/storage/database.ts index 6bb4a5b8..7fb86247 100644 --- a/src/ts/storage/database.ts +++ b/src/ts/storage/database.ts @@ -695,7 +695,8 @@ export interface character{ hfTTS?: { model: string language: string - } + }, + vits?: OnnxModelFiles } @@ -1115,6 +1116,7 @@ export function setPreset(db:Database, newPres: botPreset){ import { encode as encodeMsgpack, decode as decodeMsgpack } from "msgpackr"; import * as fflate from "fflate"; +import type { OnnxModelFiles } from '../process/embedding/transformers'; export async function downloadPreset(id:number){ saveCurrentPreset() diff --git a/src/ts/storage/globalApi.ts b/src/ts/storage/globalApi.ts index b58fadc3..9f9b9a82 100644 --- a/src/ts/storage/globalApi.ts +++ b/src/ts/storage/globalApi.ts @@ -11,7 +11,7 @@ import { checkOldDomain, checkUpdate } from "../update"; import { botMakerMode, selectedCharID } from "../stores"; import { Body, ResponseType, fetch as TauriFetch } from "@tauri-apps/api/http"; import { loadPlugins } from "../plugins/plugins"; -import { alertConfirm, alertError } from "../alert"; +import { alertConfirm, alertError, alertNormal } from "../alert"; import { checkDriverInit, syncDrive } from "../drive/drive"; import { hasher } from "../parser"; import { characterURLImport, hubURL } from "../characterCards"; @@ -231,6 +231,15 @@ export async function saveAsset(data:Uint8Array, customId:string = '', fileName: } } +export async function loadAsset(id:string){ + if(isTauri){ + return await readBinaryFile(id,{dir: BaseDirectory.AppData}) + } + else{ + return await forageStorage.getItem(id) as Uint8Array + } +} + let lastSave = '' export async function saveDb(){ @@ -369,6 +378,7 @@ export async function loadData() { throw "Your save file is corrupted" } } + await registerSw() await checkUpdate() await changeFullscreen() @@ -432,15 +442,7 @@ export async function loadData() { } if(navigator.serviceWorker && (!Capacitor.isNativePlatform())){ usingSw = true - await navigator.serviceWorker.register("/sw.js", { - scope: "/" - }); - - await sleep(100) - const da = await fetch('/sw/init') - if(!(da.status >= 200 && da.status < 300)){ - location.reload() - } + await registerSw() } else{ usingSw = false @@ -792,6 +794,20 @@ export async function globalFetch(url:string, arg:{ } } +async function registerSw() { + await navigator.serviceWorker.register("/sw.js", { + scope: "/" + }); + await sleep(100) + const da = await fetch('/sw/init') + if(!(da.status >= 200 && da.status < 300)){ + location.reload() + } + else{ + + } +} + const re = /\\/g function getBasename(data:string){ const splited = data.replace(re, '/').split('/') @@ -833,6 +849,13 @@ export function getUnpargeables(db:Database, uptype:'basename'|'pure' = 'basenam addUnparge(em[1]) } } + if(cha.vits){ + const keys = Object.keys(cha.vits.files) + 
+            for(const key of keys){
+                const vit = cha.vits.files[key]
+                addUnparge(vit)
+            }
+        }
     }
 }
 
@@ -1044,7 +1067,7 @@ async function pargeChunks(){
         const assets = await readDir('assets', {dir: BaseDirectory.AppData})
         for(const asset of assets){
             const n = getBasename(asset.name)
-            if(unpargeable.includes(n) || (!n.endsWith('png'))){
+            if(unpargeable.includes(n)){
             }
             else{
                 await removeFile(asset.path)
@@ -1054,8 +1077,11 @@
     else{
         const indexes = await forageStorage.keys()
         for(const asset of indexes){
+            if(!asset.startsWith('assets/')){
+                continue
+            }
             const n = getBasename(asset)
-            if(unpargeable.includes(n) || (!asset.endsWith(".png"))){
+            if(unpargeable.includes(n)){
             }
             else{
                 await forageStorage.removeItem(asset)
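End-to-end, the pieces in this patch form a single pipeline: registerOnnxModel unpacks a model zip into individually stored assets keyed by virtual URLs under /transformers/<id>/, the character keeps that OnnxModelFiles map in data.vits, and runVITS replays the assets into the service-worker cache before asking transformers.js for a pipeline by model id, which now resolves under env.localModelPath = "/transformers/" (hub models keep resolving through env.remoteHost). A sketch under those assumptions; demoLocalVits is an illustrative wrapper, everything it calls comes from this patch.

```ts
import { registerOnnxModel, runVITS, type OnnxModelFiles } from "src/ts/process/embedding/transformers";

// Illustrative wrapper: wires the zip picker to a direct TTS call.
async function demoLocalVits(text: string): Promise<void> {
    // 1. The user picks a .zip; .onnx weights and small support files
    //    (tokenizer.json, config.json, ...) are persisted via saveAsset,
    //    keyed by virtual URLs of the form /transformers/<id>/<entry>.
    const model: OnnxModelFiles | undefined = await registerOnnxModel();
    if (!model) {
        return; // the picker was cancelled
    }

    // 2. runVITS uploads any missing file into the service-worker cache
    //    (/sw/check/ + /sw/register/), then builds the text-to-speech
    //    pipeline from the model id and plays the synthesized audio.
    await runVITS(text, model);
}
```

In the patch itself, sayTTS makes the same call with the stored map: the new 'vits' case in tts.ts passes character.vits straight to runVITS.
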
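The getUnpargeables change pairs with the relaxed filters in pargeChunks: previously only .png assets were ever purged, while now every asset not reported as unpargeable is removed. The values of OnnxModelFiles.files are the asset ids returned by saveAsset, so they have to be reported or the next purge would delete the model. An illustrative shape (the concrete id strings are made up):

```ts
import type { OnnxModelFiles } from "src/ts/process/embedding/transformers";

// Keys are the virtual URLs built in registerOnnxModel; values are
// whatever saveAsset returned. The ids below are invented for the example.
const vits: OnnxModelFiles = {
    id: "f3a9c2",
    name: "my-voice-model.zip",
    files: {
        "/transformers/f3a9c2/onnx/model.onnx": "a1b2c3d4e5.onnx",
        "/transformers/f3a9c2/tokenizer.json": "9f8e7d6c5b.json",
    },
};

// getUnpargeables walks exactly these values; anything it misses is
// deleted the next time pargeChunks() runs.
const mustSurvive = Object.values(vits.files);
```
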
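Finally, the /sw/init probe now answers "v2" instead of "true". registerSw in globalApi.ts still only checks the HTTP status, but the body gives a client a way to notice an outdated worker. A hypothetical version gate, not part of this patch:

```ts
// Hypothetical client-side gate on the new init response. The patch
// itself only checks the status code; the "v2" body is not yet used.
async function assertWorkerVersion(): Promise<void> {
    const res = await fetch("/sw/init");
    if (!res.ok || (await res.text()) !== "v2") {
        location.reload(); // worker missing or outdated
    }
}
```
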