From 64ee71e2c6ca2b68fbf593f357653bca8c672889 Mon Sep 17 00:00:00 2001 From: kwaroran Date: Thu, 17 Aug 2023 14:49:06 +0900 Subject: [PATCH] [feat] add local code --- src-tauri/src/main.rs | 105 ++++++++++++++- src/lib/ChatScreens/DefaultChatScreen.svelte | 2 +- src/lib/ChatScreens/Suggestion.svelte | 2 +- src/lib/Setting/Pages/BotSettings.svelte | 2 +- src/lib/UI/ModelList.svelte | 16 ++- src/ts/alert.ts | 7 + src/ts/process/models/local.ts | 134 +++++++++++++++++++ 7 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 src/ts/process/models/local.ts diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 93fb2a6b..fbb7ec50 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -122,9 +122,112 @@ fn check_auth(fpath: String, auth: String) -> bool{ } +use std::process::Command; + +#[tauri::command] +fn check_requirements_local() -> String{ + let mut py = Command::new("python"); + let output = py.arg("--version").output(); + match output { + Ok(o) => { + let res = String::from_utf8(o.stdout).unwrap(); + if !res.starts_with("Python ") { + return "Python is not installed".to_string() + } + println!("{}", res); + }, + Err(e) => { + println!("{}", e); + return "Python is not installed, or not loadable".to_string() + } + } + + let mut git = Command::new("git"); + let output = git.arg("--version").output(); + match output { + Ok(o) => { + let res = String::from_utf8(o.stdout).unwrap(); + if !res.starts_with("git version ") { + return "Git is not installed".to_string() + } + println!("{}", res); + }, + Err(e) => { + println!("{}", e); + return "Git is not installed, or not loadable".to_string() + } + } + + return "success".to_string() +} + +#[tauri::command] +fn run_server_local(){ + let app_base_path = tauri::api::path::data_dir().unwrap().join("co.aiclient.risu"); + + //check app base path exists + if !app_base_path.exists() { + std::fs::create_dir_all(&app_base_path).unwrap(); + } + + let server_path = 
app_base_path.clone().join("local_server"); + + //check server path exists + if !&server_path.exists() { + //git clone server + let mut git = Command::new("git"); + let output = git + .current_dir(&app_base_path.clone()) + .arg("clone") + .arg("https://github.com/kwaroran/risu-exllama-connector.git") + .output(); + match output { + Ok(o) => { + let res = String::from_utf8(o.stdout).unwrap(); + println!("output: {}", res); + }, + Err(e) => { + println!("{}", e); + return + } + } + + println!("cloned"); + + let git_cloned_path = app_base_path.clone().join("risu-exllama-connector"); + + println!("git_cloned_path: {}", git_cloned_path.display()); + //rename folder to local_server + std::fs::rename(git_cloned_path, server_path.clone()).unwrap(); + } + + + //check os is windows + if cfg!(target_os = "windows") { + println!("windows runner"); + let command_location = &server_path.clone().join("run.cmd"); + let mut server = Command::new(command_location); + let mut _child = server.current_dir(server_path).spawn().expect("failed to execute process"); + } + else{ + println!("linux/mac runner"); + let command_location = &server_path.clone().join("run.sh"); + let mut server = Command::new(command_location); + let mut _child = server.current_dir(server_path).spawn().expect("failed to execute process"); + } + return + +} + fn main() { tauri::Builder::default() - .invoke_handler(tauri::generate_handler![greet, native_request, check_auth]) + .invoke_handler(tauri::generate_handler![ + greet, + native_request, + check_auth, + check_requirements_local, + run_server_local + ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/src/lib/ChatScreens/DefaultChatScreen.svelte b/src/lib/ChatScreens/DefaultChatScreen.svelte index 97710c49..2d0f482e 100644 --- a/src/lib/ChatScreens/DefaultChatScreen.svelte +++ b/src/lib/ChatScreens/DefaultChatScreen.svelte @@ -404,7 +404,7 @@ {#if $DataBase.useAutoSuggestions} messageInput=( - 
($DataBase.subModel === "textgen_webui" || $DataBase.subModel === "mancer") && $DataBase.autoSuggestClean + ($DataBase.subModel === "textgen_webui" || $DataBase.subModel === "mancer" || $DataBase.subModel.startsWith('local_')) && $DataBase.autoSuggestClean ? msg.replace(/ +\(.+?\) *$| - [^"'*]*?$/, '') : msg )} {send}/> diff --git a/src/lib/ChatScreens/Suggestion.svelte b/src/lib/ChatScreens/Suggestion.svelte index 2a0f6fa4..5c735401 100644 --- a/src/lib/ChatScreens/Suggestion.svelte +++ b/src/lib/ChatScreens/Suggestion.svelte @@ -75,7 +75,7 @@ } ] - if($DataBase.subModel === "textgen_webui" || $DataBase.subModel === 'mancer'){ + if($DataBase.subModel === "textgen_webui" || $DataBase.subModel === 'mancer' || $DataBase.subModel.startsWith('local_')){ promptbody = [ { role: 'system', diff --git a/src/lib/Setting/Pages/BotSettings.svelte b/src/lib/Setting/Pages/BotSettings.svelte index 646a0b12..1fcbe0e7 100644 --- a/src/lib/Setting/Pages/BotSettings.svelte +++ b/src/lib/Setting/Pages/BotSettings.svelte @@ -265,7 +265,7 @@ {/if} {($DataBase.temperature / 100).toFixed(2)} -{#if $DataBase.aiModel === 'textgen_webui' || $DataBase.subModel === 'mancer'} +{#if $DataBase.aiModel === 'textgen_webui' || $DataBase.subModel === 'mancer' || $DataBase.subModel.startsWith('local_')} Repetition Penalty {($DataBase.ooba.repetition_penalty).toFixed(2)} diff --git a/src/lib/UI/ModelList.svelte b/src/lib/UI/ModelList.svelte index ba24ff93..e8ed0cab 100644 --- a/src/lib/UI/ModelList.svelte +++ b/src/lib/UI/ModelList.svelte @@ -4,6 +4,8 @@ import Arcodion from "./Arcodion.svelte"; import { language } from "src/lang"; import { isNodeServer, isTauri } from "src/ts/storage/globalApi"; + import { checkLocalModel } from "src/ts/process/models/local"; + import { alertError } from "src/ts/alert"; let openAdv = true export let value = "" @@ -32,6 +34,8 @@ return "GPT-4 0613" case "gpt4_32k_0613": return "GPT-4 32k 0613" + case 'local_gptq': + return 'Local Model GPTQ' case "palm2": return 
// ---- src/ts/alert.ts (added function) ----

/** Resets the alert store to type 'none' with an empty message. */
export function alertClear(){
    alertStore.set({
        'type': 'none',
        'msg': ''
    })
}

// ---- src/ts/process/models/local.ts (new file) ----

import { invoke } from "@tauri-apps/api/tauri";
import { globalFetch } from "src/ts/storage/globalApi";
import { sleep } from "src/ts/util";
// Namespace import: the functions used below (`join`, `appDataDir`) are named exports.
import * as path from "@tauri-apps/api/path";
import { exists } from "@tauri-apps/api/fs";
import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
import { get } from "svelte/store";
import { DataBase } from "src/ts/storage/database";

// True once a start of the local server has been requested from this module.
let serverRunning = false;

/**
 * Invokes the `check_requirements_local` command.
 * Resolves to "success" or to a message naming the missing tool.
 */
export function checkLocalModel():Promise<string>{
    return invoke("check_requirements_local")
}

/**
 * Invokes `run_server_local` at most once per successful start.
 * If the invoke rejects, the flag is cleared again so a later call can retry,
 * and the error is re-thrown to the caller.
 */
export async function startLocalModelServer(){
    if(serverRunning){
        return
    }
    // Set before awaiting so overlapping calls do not start a second server.
    serverRunning = true
    try {
        await invoke("run_server_local")
    } catch (error) {
        serverRunning = false
        throw error
    }
}

/** Resolves to whether `<appDataDir>/local_server` exists. */
export async function checkLocalServerInstalled() {
    const serverDir = await path.join(await path.appDataDir(), 'local_server')
    return await exists(serverDir)
}

/** Request body for the local server's /load/ endpoint; only `dir` is required. */
export interface LocalLoaderItem {
    dir: string;
    max_seq_len?: number ;
    max_input_len?: number ;
    max_attention_size?: number ;
    compress_pos_emb?: number ;
    alpha_value?: number ;
    gpu_peer_fixed?: boolean ;
    auto_map?: boolean ;
    use_flash_attn_2?: boolean ;
    matmul_recons_thd?: number ;
    fused_mlp_thd?: number ;
    sdp_thd?: number ;
    fused_attn?: boolean ;
    matmul_fused_remap?: boolean ;
    rmsnorm_no_half2?: boolean ;
    rope_no_half2?: boolean ;
    matmul_no_half2?: boolean ;
    silu_no_half2?: boolean ;
    concurrent_streams?: boolean ;
}
// class GeneratorItem(BaseModel):
//     temperature: Union[float, None]
//     top_k: Union[int, None]
//     top_p: Union[float, None]
//     min_p: Union[float, None]
//     typical: Union[float, None]
//     token_repetition_penalty_max: Union[float, None]
//     token_repetition_penalty_sustain: Union[int, None]
//     token_repetition_penalty_decay: Union[int, None]
//     beams: Union[int, None]
//     beam_length: Union[int, None]
//     disallowed_tokens: Union[list[int], None]
//     prompt: str
//     max_new_tokens: Union[int, None]

/**
 * Request body for the local server's /generate/ endpoint.
 * Field-for-field counterpart of the Python model quoted above; only `prompt` is required.
 */
interface LocalGeneratorItem {
    temperature?: number;
    top_k?: number;
    top_p?: number;
    min_p?: number;
    typical?: number;
    token_repetition_penalty_max?: number;
    token_repetition_penalty_sustain?: number;
    token_repetition_penalty_decay?: number;
    beams?: number;
    beam_length?: number;
    disallowed_tokens?: number[];
    prompt: string;
    max_new_tokens?: number;
}


/**
 * Starts the local server, waits until it answers, then asks it to load the model.
 * Progress is shown through alertWait; any failure ends in alertError.
 */
export async function loadExllamaFull(){

    try {
        // Check the install state BEFORE starting the server: the `run_server_local`
        // command creates `local_server` before it returns, so a check made
        // afterwards reports "installed" even on the very first run.
        const alreadyInstalled = await checkLocalServerInstalled()
        if(alreadyInstalled){
            alertWait("Loading exllama")
        }
        else{
            alertWait("Installing & Loading exllama, this would take a while for the first time")
        }
        await startLocalModelServer()

        // Poll once per second until the server root answers.
        // NOTE(review): there is no upper bound; if the server never comes up this
        // keeps waiting. Confirm whether a timeout is wanted.
        while(true){
            try {
                const status = await globalFetch("http://localhost:7239/")
                if(status.ok){
                    break
                }
            } catch (error) {}
            await sleep(1000)
        }

        const body:LocalLoaderItem = {
            dir: "exllama",
        }

        alertWait("Loading Local Model")
        const loaded = await globalFetch("http://localhost:7239/load/", {
            body: body
        })
        // A rejected load must not be reported as success by clearing the alert.
        if(!loaded.ok){
            throw new Error("model load request failed")
        }
        alertClear()
    } catch (error) {
        alertError("Error when loading Exllama: " + error)
    }

}


/**
 * Sends `prompt` with the sampling settings from the database to the local
 * server's /generate/ endpoint, logs the response and returns it.
 */
export async function runLocalModel(prompt:string){
    const db = get(DataBase)

    const body:LocalGeneratorItem = {
        prompt: prompt,
        // The settings UI displays `temperature / 100`, so the stored value is
        // scaled by 100 and has to be converted back before it is sent.
        temperature: db.temperature / 100,
        top_k: db.ooba.top_k,
        top_p: db.ooba.top_p,
        typical: db.ooba.typical_p,
        max_new_tokens: db.maxResponse
    }

    const gen = await globalFetch("http://localhost:7239/generate/", {
        body: body
    })

    console.log(gen)
    return gen
}