Add model selection for VitsModel

This commit is contained in:
kwaroran
2024-01-06 06:45:18 +09:00
parent 66c6511684
commit 7344e566f4
7 changed files with 366 additions and 220 deletions

View File

@@ -7,7 +7,13 @@ self.addEventListener('fetch', (event) => {
try { try {
switch (path[2]){ switch (path[2]){
case "check":{ case "check":{
event.respondWith(checkCache(url)) let targetUrl = url
const headers = event.request.headers
const headerUrl = headers.get('x-register-url')
if(headerUrl){
targetUrl.pathname = decodeURIComponent(headerUrl)
}
event.respondWith(checkCache(targetUrl))
break break
} }
case "img": { case "img": {
@@ -15,20 +21,20 @@ self.addEventListener('fetch', (event) => {
break break
} }
case "register": { case "register": {
let targerUrl = url let targetUrl = url
const headers = event.request.headers const headers = event.request.headers
const headerUrl = headers.get('x-register-url') const headerUrl = headers.get('x-register-url')
if(headerUrl){ if(headerUrl){
targerUrl = new URL(headerUrl) targetUrl.pathname = decodeURIComponent(headerUrl)
} }
const noContentType = headers.get('x-no-content-type') === 'true' const noContentType = headers.get('x-no-content-type') === 'true'
event.respondWith( event.respondWith(
registerCache(targerUrl, event.request.arrayBuffer(), noContentType) registerCache(targetUrl, event.request.arrayBuffer(), noContentType)
) )
break break
} }
case "init":{ case "init":{
event.respondWith(new Response("true")) event.respondWith(new Response("v2"))
} }
default: { default: {
event.respondWith(new Response( event.respondWith(new Response(
@@ -74,9 +80,11 @@ async function check(){
async function registerCache(urlr, buffer, noContentType = false){ async function registerCache(urlr, buffer, noContentType = false){
const cache = await caches.open('risuCache') const cache = await caches.open('risuCache')
const url = new URL(urlr) const url = new URL(urlr)
if(!noContentType){
let path = url.pathname.split('/') let path = url.pathname.split('/')
path[2] = 'img' path[2] = 'img'
url.pathname = path.join('/') url.pathname = path.join('/')
}
const buf = new Uint8Array(await buffer) const buf = new Uint8Array(await buffer)
let headers = { let headers = {
"cache-control": "max-age=604800", "cache-control": "max-age=604800",

View File

@@ -484,4 +484,5 @@ export const languageEnglish = {
chatAsOriginalOnSystem: "Send as original role", chatAsOriginalOnSystem: "Send as original role",
exportAsDataset: "Export Save as Dataset", exportAsDataset: "Export Save as Dataset",
editTranslationDisplay: "Edit Translation Display", editTranslationDisplay: "Edit Translation Display",
selectModel: "Select Model",
} }

View File

@@ -29,6 +29,7 @@
import TriggerList from "./Scripts/TriggerList.svelte"; import TriggerList from "./Scripts/TriggerList.svelte";
import CheckInput from "../UI/GUI/CheckInput.svelte"; import CheckInput from "../UI/GUI/CheckInput.svelte";
import { updateInlayScreen } from "src/ts/process/inlayScreen"; import { updateInlayScreen } from "src/ts/process/inlayScreen";
import { registerOnnxModel } from "src/ts/process/embedding/transformers";
let subMenu = 0 let subMenu = 0
@@ -626,6 +627,19 @@
<span class="text-textcolor">Language</span> <span class="text-textcolor">Language</span>
<TextInput additionalClass="mb-4 mt-2" bind:value={currentChar.data.hfTTS.language} placeholder="en" /> <TextInput additionalClass="mb-4 mt-2" bind:value={currentChar.data.hfTTS.language} placeholder="en" />
{/if} {/if}
{#if currentChar.data.ttsMode === 'vits'}
{#if currentChar.data.vits}
<span class="text-textcolor">{currentChar.data.vits.name ?? 'Unnamed VitsModel'}</span>
{:else}
<span class="text-textcolor">No Model</span>
{/if}
<Button on:click={async () => {
const model = await registerOnnxModel()
if(model && currentChar.type === 'character'){
currentChar.data.vits = model
}
}}>{language.selectModel}</Button>
{/if}
{#if currentChar.data.ttsMode} {#if currentChar.data.ttsMode}
<div class="flex items-center mt-2"> <div class="flex items-center mt-2">
<Check bind:check={currentChar.data.ttsReadOnlyQuoted} name={language.ttsReadOnlyQuoted}/> <Check bind:check={currentChar.data.ttsReadOnlyQuoted} name={language.ttsReadOnlyQuoted}/>

View File

@@ -1,6 +1,11 @@
import {env, AutoTokenizer, pipeline, VitsModel, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers'; import {env, AutoTokenizer, pipeline, type SummarizationOutput, type TextGenerationConfig, type TextGenerationOutput, FeatureExtractionPipeline, TextToAudioPipeline } from '@xenova/transformers';
import { unzip } from 'fflate';
import { loadAsset, saveAsset } from 'src/ts/storage/globalApi';
import { selectSingleFile } from 'src/ts/util';
import { v4 } from 'uuid';
env.localModelPath = "https://sv.risuai.xyz/transformers/" env.localModelPath = "/transformers/"
env.remoteHost = "https://sv.risuai.xyz/transformers/"
export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => { export const runTransformers = async (baseText:string, model:string,config:TextGenerationConfig = {}) => {
let text = baseText let text = baseText
@@ -61,11 +66,49 @@ export const runEmbedding = async (text: string):Promise<Float32Array> => {
let synthesizer:TextToAudioPipeline = null let synthesizer:TextToAudioPipeline = null
let lastSynth:string = null let lastSynth:string = null
export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng') => {
/**
 * Descriptor of a user-supplied ONNX model whose files were stored as assets.
 */
export interface OnnxModelFiles {
    /**
     * Lookup from a model file path to the asset id it was saved under.
     * `registerOnnxModel` builds the paths as `/transformers/<id>/<entry name>`.
     */
    files: Record<string, string>;
    /** Identifier of the model; `registerOnnxModel` uses a UUID with the dashes removed. */
    id: string;
    /** Optional display name; `registerOnnxModel` fills it with the selected file's name. */
    name?: string;
}
export const runVITS = async (text: string, modelData:string|OnnxModelFiles = 'Xenova/mms-tts-eng') => {
const {WaveFile} = await import('wavefile') const {WaveFile} = await import('wavefile')
if((!synthesizer) || (lastSynth !== model)){ if(modelData === null){
lastSynth = model return
synthesizer = await pipeline('text-to-speech', model); }
if(typeof modelData === 'string'){
if((!synthesizer) || (lastSynth !== modelData)){
lastSynth = modelData
synthesizer = await pipeline('text-to-speech', modelData);
}
}
else{
if((!synthesizer) || (lastSynth !== modelData.id)){
const files = modelData.files
const keys = Object.keys(files)
for(const key of keys){
const hasCache:boolean = (await (await fetch("/sw/check/", {
headers: {
'x-register-url': encodeURIComponent(key)
}
})).json()).able
if(!hasCache){
await fetch("/sw/register/", {
method: "POST",
body: await loadAsset(files[key]),
headers: {
'x-register-url': encodeURIComponent(key),
'x-no-content-type': 'true'
}
})
}
}
lastSynth = modelData.id
synthesizer = await pipeline('text-to-speech', modelData.id);
}
} }
let out = await synthesizer(text, {}); let out = await synthesizer(text, {});
const wav = new WaveFile(); const wav = new WaveFile();
@@ -78,3 +121,51 @@ export const runVITS = async (text: string, model:string = 'Xenova/mms-tts-eng')
sourceNode.start(); sourceNode.start();
}); });
} }
/**
 * Lets the user pick a zipped ONNX model, stores every extracted entry as an
 * asset and returns the path-to-asset mapping for the model.
 *
 * Each entry is saved with `saveAsset` and keyed by
 * `/transformers/<id>/<entry name>`, where `<id>` is a fresh UUID with the
 * dashes removed.
 *
 * @returns The model descriptor (`files`, `name`, `id`). Resolves to
 * `undefined` when no file was selected, so callers must check the result
 * before using it.
 * @throws Rejects with the error reported by `unzip` when extraction fails.
 */
export const registerOnnxModel = async ():Promise<OnnxModelFiles> => {
    const modelFile = await selectSingleFile(['zip'])
    if(!modelFile){
        return
    }
    // The id is only needed once a file has actually been chosen.
    const id = v4().replace(/-/g, '')

    // Wrap the callback-style unzip in a typed promise so the entries are
    // known to be byte arrays instead of `unknown`.
    const unzipped = await new Promise<{[key:string]:Uint8Array}>((res, rej) => {
        unzip(modelFile.data, {
            // Keeps `.onnx` entries, entries whose `size` is below 10,000,000,
            // and entries whose name contains `.git`.
            // NOTE(review): the `.git` clause includes such entries rather than
            // excluding them - confirm this is intended.
            filter: (file) => {
                return file.name.endsWith('.onnx') || file.size < 10_000_000 || file.name.includes('.git')
            }
        }, (err, extracted) => {
            if(err){
                rej(err)
            }
            else{
                res(extracted)
            }
        })
    })

    const fileIdMapped:{[key:string]:string} = {}
    for(const [entryName, bytes] of Object.entries(unzipped)){
        const assetId = await saveAsset(bytes)
        // Drop one leading slash so the joined path has no `//`.
        const relativePath = entryName.startsWith('/') ? entryName.substring(1) : entryName
        fileIdMapped['/transformers/' + id + '/' + relativePath] = assetId
    }

    return {
        files: fileIdMapped,
        name: modelFile.name,
        id: id,
    }
}

View File

@@ -5,11 +5,12 @@ import { runTranslator, translateVox } from "../translator/translator";
import { globalFetch } from "../storage/globalApi"; import { globalFetch } from "../storage/globalApi";
import { language } from "src/lang"; import { language } from "src/lang";
import { getCurrentCharacter, sleep } from "../util"; import { getCurrentCharacter, sleep } from "../util";
import { runVITS } from "./embedding/transformers"; import { registerOnnxModel, runVITS } from "./embedding/transformers";
let sourceNode:AudioBufferSourceNode = null let sourceNode:AudioBufferSourceNode = null
export async function sayTTS(character:character,text:string) { export async function sayTTS(character:character,text:string) {
try {
if(!character){ if(!character){
const v = getCurrentCharacter() const v = getCurrentCharacter()
if(v.type === 'group'){ if(v.type === 'group'){
@@ -223,9 +224,12 @@ export async function sayTTS(character:character,text:string) {
} }
} }
case 'vits':{ case 'vits':{
await runVITS(text) await runVITS(text, character.vits)
} }
} }
} catch (error) {
alertError(`TTS Error: ${error}`)
}
} }
export const oaiVoices = [ export const oaiVoices = [

View File

@@ -695,7 +695,8 @@ export interface character{
hfTTS?: { hfTTS?: {
model: string model: string
language: string language: string
} },
vits?: OnnxModelFiles
} }
@@ -1115,6 +1116,7 @@ export function setPreset(db:Database, newPres: botPreset){
import { encode as encodeMsgpack, decode as decodeMsgpack } from "msgpackr"; import { encode as encodeMsgpack, decode as decodeMsgpack } from "msgpackr";
import * as fflate from "fflate"; import * as fflate from "fflate";
import type { OnnxModelFiles } from '../process/embedding/transformers';
export async function downloadPreset(id:number){ export async function downloadPreset(id:number){
saveCurrentPreset() saveCurrentPreset()

View File

@@ -11,7 +11,7 @@ import { checkOldDomain, checkUpdate } from "../update";
import { botMakerMode, selectedCharID } from "../stores"; import { botMakerMode, selectedCharID } from "../stores";
import { Body, ResponseType, fetch as TauriFetch } from "@tauri-apps/api/http"; import { Body, ResponseType, fetch as TauriFetch } from "@tauri-apps/api/http";
import { loadPlugins } from "../plugins/plugins"; import { loadPlugins } from "../plugins/plugins";
import { alertConfirm, alertError } from "../alert"; import { alertConfirm, alertError, alertNormal } from "../alert";
import { checkDriverInit, syncDrive } from "../drive/drive"; import { checkDriverInit, syncDrive } from "../drive/drive";
import { hasher } from "../parser"; import { hasher } from "../parser";
import { characterURLImport, hubURL } from "../characterCards"; import { characterURLImport, hubURL } from "../characterCards";
@@ -231,6 +231,15 @@ export async function saveAsset(data:Uint8Array, customId:string = '', fileName:
} }
} }
/**
 * Loads the bytes of a stored asset.
 *
 * When `isTauri` is set the asset is read with `readBinaryFile` relative to
 * the AppData base directory; otherwise it is fetched from `forageStorage`.
 *
 * @param id Asset identifier, passed through unchanged to the storage backend.
 * @returns The asset contents as returned by the backend.
 */
export async function loadAsset(id:string){
    if(!isTauri){
        const stored = await forageStorage.getItem(id)
        return stored as Uint8Array
    }
    const fileBytes = await readBinaryFile(id, {dir: BaseDirectory.AppData})
    return fileBytes
}
let lastSave = '' let lastSave = ''
export async function saveDb(){ export async function saveDb(){
@@ -369,6 +378,7 @@ export async function loadData() {
throw "Your save file is corrupted" throw "Your save file is corrupted"
} }
} }
await registerSw()
await checkUpdate() await checkUpdate()
await changeFullscreen() await changeFullscreen()
@@ -432,15 +442,7 @@ export async function loadData() {
} }
if(navigator.serviceWorker && (!Capacitor.isNativePlatform())){ if(navigator.serviceWorker && (!Capacitor.isNativePlatform())){
usingSw = true usingSw = true
await navigator.serviceWorker.register("/sw.js", { await registerSw()
scope: "/"
});
await sleep(100)
const da = await fetch('/sw/init')
if(!(da.status >= 200 && da.status < 300)){
location.reload()
}
} }
else{ else{
usingSw = false usingSw = false
@@ -792,6 +794,20 @@ export async function globalFetch(url:string, arg:{
} }
} }
/**
 * Registers the `/sw.js` service worker with scope `/` and checks that it answers.
 *
 * After registering it calls `sleep(100)` and then requests `/sw/init`; a
 * response outside the 2xx range triggers `location.reload()`.
 *
 * Returns without doing anything when `navigator.serviceWorker` is not
 * available, so callers may invoke it unconditionally.
 */
async function registerSw() {
    // Not every runtime exposes the service worker API; skip instead of
    // throwing a TypeError on the property access below.
    if(!navigator.serviceWorker){
        return
    }
    await navigator.serviceWorker.register("/sw.js", {
        scope: "/"
    });
    await sleep(100)
    const initResponse = await fetch('/sw/init')
    // `ok` is true exactly for statuses 200-299.
    if(!initResponse.ok){
        location.reload()
    }
}
const re = /\\/g const re = /\\/g
function getBasename(data:string){ function getBasename(data:string){
const splited = data.replace(re, '/').split('/') const splited = data.replace(re, '/').split('/')
@@ -833,6 +849,13 @@ export function getUnpargeables(db:Database, uptype:'basename'|'pure' = 'basenam
addUnparge(em[1]) addUnparge(em[1])
} }
} }
if(cha.vits){
const keys = Object.keys(cha.vits.files)
for(const key of keys){
const vit = cha.vits.files[key]
addUnparge(vit)
}
}
} }
} }
@@ -1044,7 +1067,7 @@ async function pargeChunks(){
const assets = await readDir('assets', {dir: BaseDirectory.AppData}) const assets = await readDir('assets', {dir: BaseDirectory.AppData})
for(const asset of assets){ for(const asset of assets){
const n = getBasename(asset.name) const n = getBasename(asset.name)
if(unpargeable.includes(n) || (!n.endsWith('png'))){ if(unpargeable.includes(n)){
} }
else{ else{
await removeFile(asset.path) await removeFile(asset.path)
@@ -1054,8 +1077,11 @@ async function pargeChunks(){
else{ else{
const indexes = await forageStorage.keys() const indexes = await forageStorage.keys()
for(const asset of indexes){ for(const asset of indexes){
if(!asset.startsWith('assets/')){
continue
}
const n = getBasename(asset) const n = getBasename(asset)
if(unpargeable.includes(n) || (!asset.endsWith(".png"))){ if(unpargeable.includes(n)){
} }
else{ else{
await forageStorage.removeItem(asset) await forageStorage.removeItem(asset)