[feat] add local code

This commit is contained in:
kwaroran
2023-08-17 14:49:06 +09:00
parent 3f9e08bfb0
commit 64ee71e2c6
7 changed files with 263 additions and 5 deletions

View File

@@ -122,9 +122,112 @@ fn check_auth(fpath: String, auth: String) -> bool{
} }
use std::process::Command;

/// Tauri command: verifies that `python` and `git` are reachable on PATH.
///
/// Returns `"success"` when both tools respond to `--version` with the
/// expected banner, otherwise a human-readable error string that the
/// frontend shows to the user.
#[tauri::command]
fn check_requirements_local() -> String {
    // NOTE(review): Python 2 printed its version to stderr; this only
    // accepts interpreters that write "Python ..." to stdout (Python 3).
    match Command::new("python").arg("--version").output() {
        Ok(o) => {
            // from_utf8_lossy instead of from_utf8().unwrap(): tool output
            // in a non-UTF-8 locale must not panic a command handler.
            let res = String::from_utf8_lossy(&o.stdout);
            if !res.starts_with("Python ") {
                return "Python is not installed".to_string();
            }
            println!("{}", res);
        }
        Err(e) => {
            println!("{}", e);
            return "Python is not installed, or not loadable".to_string();
        }
    }
    match Command::new("git").arg("--version").output() {
        Ok(o) => {
            let res = String::from_utf8_lossy(&o.stdout);
            if !res.starts_with("git version ") {
                return "Git is not installed".to_string();
            }
            println!("{}", res);
        }
        Err(e) => {
            println!("{}", e);
            return "Git is not installed, or not loadable".to_string();
        }
    }
    "success".to_string()
}
/// Tauri command: makes sure the local exllama connector server is
/// installed under the app data dir (cloning it on first run) and spawns
/// its platform run script as a detached child process.
///
/// Errors are printed and cause an early return instead of panicking, so
/// a failed install never crashes the whole app.
#[tauri::command]
fn run_server_local() {
    // Resolve the per-user data directory; bail out if unavailable
    // (original code unwrap()ped and would panic here).
    let Some(data_dir) = tauri::api::path::data_dir() else {
        println!("could not resolve the user data directory");
        return;
    };
    let app_base_path = data_dir.join("co.aiclient.risu");
    if !app_base_path.exists() {
        if let Err(e) = std::fs::create_dir_all(&app_base_path) {
            println!("failed to create app base dir: {}", e);
            return;
        }
    }
    let server_path = app_base_path.join("local_server");
    if !server_path.exists() {
        // First run: clone the connector repo into the app base dir.
        let output = Command::new("git")
            .current_dir(&app_base_path)
            .arg("clone")
            .arg("https://github.com/kwaroran/risu-exllama-connector.git")
            .output();
        match output {
            Ok(o) => {
                // Check the exit status: a failed clone (no network, repo
                // moved) previously went unnoticed and the rename below
                // then panicked on the missing directory.
                if !o.status.success() {
                    println!("git clone failed: {}", String::from_utf8_lossy(&o.stderr));
                    return;
                }
                println!("output: {}", String::from_utf8_lossy(&o.stdout));
            }
            Err(e) => {
                println!("{}", e);
                return;
            }
        }
        println!("cloned");
        let git_cloned_path = app_base_path.join("risu-exllama-connector");
        println!("git_cloned_path: {}", git_cloned_path.display());
        // Rename the checkout to the stable name the rest of the app uses.
        if let Err(e) = std::fs::rename(&git_cloned_path, &server_path) {
            println!("failed to rename server dir: {}", e);
            return;
        }
    }
    // Pick the platform run script, then spawn it detached on purpose —
    // the frontend polls the server's HTTP port to learn when it is up.
    let script = if cfg!(target_os = "windows") {
        println!("windows runner");
        "run.cmd"
    } else {
        println!("linux/mac runner");
        "run.sh"
    };
    let command_location = server_path.join(script);
    if let Err(e) = Command::new(command_location).current_dir(&server_path).spawn() {
        println!("failed to execute process: {}", e);
    }
}
fn main() { fn main() {
tauri::Builder::default() tauri::Builder::default()
.invoke_handler(tauri::generate_handler![greet, native_request, check_auth]) .invoke_handler(tauri::generate_handler![
greet,
native_request,
check_auth,
check_requirements_local,
run_server_local
])
.run(tauri::generate_context!()) .run(tauri::generate_context!())
.expect("error while running tauri application"); .expect("error while running tauri application");
} }

View File

@@ -404,7 +404,7 @@
{#if $DataBase.useAutoSuggestions} {#if $DataBase.useAutoSuggestions}
<Suggestion messageInput={(msg)=>messageInput=( <Suggestion messageInput={(msg)=>messageInput=(
($DataBase.subModel === "textgen_webui" || $DataBase.subModel === "mancer") && $DataBase.autoSuggestClean ($DataBase.subModel === "textgen_webui" || $DataBase.subModel === "mancer" || $DataBase.subModel.startsWith('local_')) && $DataBase.autoSuggestClean
? msg.replace(/ +\(.+?\) *$| - [^"'*]*?$/, '') ? msg.replace(/ +\(.+?\) *$| - [^"'*]*?$/, '')
: msg : msg
)} {send}/> )} {send}/>

View File

@@ -75,7 +75,7 @@
} }
] ]
if($DataBase.subModel === "textgen_webui" || $DataBase.subModel === 'mancer'){ if($DataBase.subModel === "textgen_webui" || $DataBase.subModel === 'mancer' || $DataBase.subModel.startsWith('local_')){
promptbody = [ promptbody = [
{ {
role: 'system', role: 'system',

View File

@@ -265,7 +265,7 @@
{/if} {/if}
<span class="text-textcolor2 mb-6 text-sm">{($DataBase.temperature / 100).toFixed(2)}</span> <span class="text-textcolor2 mb-6 text-sm">{($DataBase.temperature / 100).toFixed(2)}</span>
{#if $DataBase.aiModel === 'textgen_webui' || $DataBase.subModel === 'mancer'} {#if $DataBase.aiModel === 'textgen_webui' || $DataBase.subModel === 'mancer' || $DataBase.subModel.startsWith('local_')}
<span class="text-textcolor">Repetition Penalty</span> <span class="text-textcolor">Repetition Penalty</span>
<SliderInput min={1} max={1.5} step={0.01} bind:value={$DataBase.ooba.repetition_penalty}/> <SliderInput min={1} max={1.5} step={0.01} bind:value={$DataBase.ooba.repetition_penalty}/>
<span class="text-textcolor2 mb-6 text-sm">{($DataBase.ooba.repetition_penalty).toFixed(2)}</span> <span class="text-textcolor2 mb-6 text-sm">{($DataBase.ooba.repetition_penalty).toFixed(2)}</span>

View File

@@ -4,6 +4,8 @@
import Arcodion from "./Arcodion.svelte"; import Arcodion from "./Arcodion.svelte";
import { language } from "src/lang"; import { language } from "src/lang";
import { isNodeServer, isTauri } from "src/ts/storage/globalApi"; import { isNodeServer, isTauri } from "src/ts/storage/globalApi";
import { checkLocalModel } from "src/ts/process/models/local";
import { alertError } from "src/ts/alert";
let openAdv = true let openAdv = true
export let value = "" export let value = ""
@@ -32,6 +34,8 @@
return "GPT-4 0613" return "GPT-4 0613"
case "gpt4_32k_0613": case "gpt4_32k_0613":
return "GPT-4 32k 0613" return "GPT-4 32k 0613"
case 'local_gptq':
return 'Local Model GPTQ'
case "palm2": case "palm2":
return "PaLM2" return "PaLM2"
case "textgen_webui": case "textgen_webui":
@@ -111,10 +115,20 @@
{/if} {/if}
</Arcodion> </Arcodion>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('reverse_proxy')}}>Reverse Proxy</button> <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('reverse_proxy')}}>Reverse Proxy</button>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('openrouter')}}>OpenRouter</button> <!-- <button class="hover:bg-selected px-6 py-2 text-lg" on:click={async () => {
const res = (await checkLocalModel())
if(res === 'success'){
changeModel('local_gptq')
}
else{
alertError("python 3.10, cuda 11.7 and git must be installed to run it. " + res)
}
// changeModel('local_gptq')
}}>Local Model GPTQ</button> -->
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('textgen_webui')}}>Oobabooga WebUI</button> <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('textgen_webui')}}>Oobabooga WebUI</button>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('mancer')}}>Mancer</button> <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('mancer')}}>Mancer</button>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('palm2')}}>Google PaLM2</button> <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('palm2')}}>Google PaLM2</button>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('openrouter')}}>OpenRouter</button>
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('kobold')}}>Kobold</button> <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('kobold')}}>Kobold</button>
<Arcodion name="Novellist"> <Arcodion name="Novellist">
<button class="p-2 hover:text-green-500" on:click={() => {changeModel('novellist')}}>SuperTrin</button> <button class="p-2 hover:text-green-500" on:click={() => {changeModel('novellist')}}>SuperTrin</button>

View File

@@ -87,6 +87,13 @@ export function alertWait(msg:string){
} }
export function alertClear(){
alertStore.set({
'type': 'none',
'msg': ''
})
}
export async function alertSelectChar(){ export async function alertSelectChar(){
alertStore.set({ alertStore.set({
'type': 'selectChar', 'type': 'selectChar',

View File

@@ -0,0 +1,134 @@
import { invoke } from "@tauri-apps/api/tauri";
import { globalFetch } from "src/ts/storage/globalApi";
import { sleep } from "src/ts/util";
import path from "@tauri-apps/api/path";
import { exists } from "@tauri-apps/api/fs";
import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
import { get } from "svelte/store";
import { DataBase } from "src/ts/storage/database";
let serverRunning = false;
// Asks the Rust backend whether the local-model prerequisites (python,
// git) are available. Resolves to "success" or an error description.
export async function checkLocalModel(): Promise<string> {
    return await invoke("check_requirements_local")
}
// Launches the local server through the backend at most once per app
// session; later calls are no-ops while the module-level flag is set.
export async function startLocalModelServer() {
    if (serverRunning) {
        return
    }
    serverRunning = true
    await invoke("run_server_local")
}
// True when the cloned "local_server" directory already exists inside the
// app data directory (i.e. the first-run install has completed before).
export async function checkLocalServerInstalled() {
    const appData = await path.appDataDir()
    const serverDir = await path.join(appData, 'local_server')
    const installed = await exists(serverDir)
    return installed
}
// Request body for POST /load/ on the local exllama server. Field names
// mirror the Python-side model, so they stay snake_case on purpose.
export interface LocalLoaderItem {
// Model directory to load (interpreted by the server).
dir: string;
// All remaining fields are optional exllama tuning knobs; when omitted
// the server-side defaults apply. NOTE(review): their exact semantics are
// defined by the risu-exllama-connector server — confirm there before
// documenting individual fields in detail.
max_seq_len?: number ;
max_input_len?: number ;
max_attention_size?: number ;
compress_pos_emb?: number ;
alpha_value?: number ;
gpu_peer_fixed?: boolean ;
auto_map?: boolean ;
use_flash_attn_2?: boolean ;
matmul_recons_thd?: number ;
fused_mlp_thd?: number ;
sdp_thd?: number ;
fused_attn?: boolean ;
matmul_fused_remap?: boolean ;
rmsnorm_no_half2?: boolean ;
rope_no_half2?: boolean ;
matmul_no_half2?: boolean ;
silu_no_half2?: boolean ;
concurrent_streams?: boolean ;
}
// class GeneratorItem(BaseModel):
// temperature: Union[float, None]
// top_k: Union[int, None]
// top_p: Union[float, None]
// min_p: Union[float, None]
// typical: Union[float, None]
// token_repetition_penalty_max: Union[float, None]
// token_repetition_penalty_sustain: Union[int, None]
// token_repetition_penalty_decay: Union[int, None]
// beams: Union[int, None]
// beam_length: Union[int, None]
// disallowed_tokens: Union[list[int], None]
// prompt: str
// max_new_tokens: Union[int, None]
// Request body for POST /generate/ on the local server: the required
// prompt plus optional sampler settings. Mirrors the Python GeneratorItem
// model quoted in the comment above (fields left undefined fall back to
// the server's defaults).
interface LocalGeneratorItem {
temperature?: number;
top_k?: number;
top_p?: number;
min_p?: number;
typical?: number;
token_repetition_penalty_max?: number;
token_repetition_penalty_sustain?: number;
token_repetition_penalty_decay?: number;
beams?: number;
beam_length?: number;
disallowed_tokens?: number[];
// The text to continue — the only required field.
prompt: string;
max_new_tokens?: number;
}
export async function loadExllamaFull(){
try {
await startLocalModelServer()
if(await checkLocalServerInstalled()){
alertWait("Loading exllama")
}
else{
alertWait("Installing & Loading exllama, this would take a while for the first time")
}
while(true){
//check is server is running by fetching the status
try {
const res = await globalFetch("http://localhost:7239/")
if(res.ok){
break
}
} catch (error) {}
await sleep(1000)
}
const body:LocalLoaderItem = {
dir: "exllama",
}
alertWait("Loading Local Model")
const res = await globalFetch("http://localhost:7239/load/", {
body: body
})
alertClear()
} catch (error) {
alertError("Error when loading Exllama: " + error)
}
}
// Sends a generation request to the local model server using the user's
// sampler settings, and returns the fetch result so callers can use the
// generated text (previously it was only logged and discarded).
export async function runLocalModel(prompt: string) {
    const db = get(DataBase)
    const body: LocalGeneratorItem = {
        prompt: prompt,
        // db.temperature is stored as an integer percentage — the settings
        // slider displays (temperature / 100) — so rescale it before
        // handing it to the sampler instead of passing the raw 0-100 value.
        temperature: db.temperature / 100,
        top_k: db.ooba.top_k,
        top_p: db.ooba.top_p,
        typical: db.ooba.typical_p,
        max_new_tokens: db.maxResponse
    }
    const gen = await globalFetch("http://localhost:7239/generate/", {
        body: body
    })
    console.log(gen)
    return gen
}