// risuai/src/ts/process/models/local.ts

import { invoke } from "@tauri-apps/api/tauri";
import { globalFetch } from "src/ts/storage/globalApi";
import { sleep } from "src/ts/util";
import * as path from "@tauri-apps/api/path";
import { exists } from "@tauri-apps/api/fs";
import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
import { get } from "svelte/store";
import { DataBase } from "src/ts/storage/database";
let serverRunning = false;
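
// Asks the Tauri backend whether the requirements for running a local model are met.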
export function checkLocalModel(): Promise<string> {
    return invoke("check_requirements_local")
}
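
// Starts the local model server through the Tauri backend.
// Guarded so the server process is only launched once per session.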
export async function startLocalModelServer() {
    if (!serverRunning) {
        serverRunning = true
        await invoke("run_server_local")
    }
}
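
// Treats the presence of local_server/key.txt in the app data directory
// as the marker that the local server has been installed.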
export async function checkLocalServerInstalled() {
    console.log(await path.appDataDir())
    const p = await path.join(await path.appDataDir(), 'local_server', 'key.txt')
    return await exists(p)
}
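
// Request body for the /load/ endpoint; the optional fields correspond to
// exllama loader settings on the server side.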
export interface LocalLoaderItem {
    dir: string;
    max_seq_len?: number;
    max_input_len?: number;
    max_attention_size?: number;
    compress_pos_emb?: number;
    alpha_value?: number;
    gpu_peer_fixed?: boolean;
    auto_map?: boolean;
    use_flash_attn_2?: boolean;
    matmul_recons_thd?: number;
    fused_mlp_thd?: number;
    sdp_thd?: number;
    fused_attn?: boolean;
    matmul_fused_remap?: boolean;
    rmsnorm_no_half2?: boolean;
    rope_no_half2?: boolean;
    matmul_no_half2?: boolean;
    silu_no_half2?: boolean;
    concurrent_streams?: boolean;
}
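
// The server's request schema, kept here for reference (Python/Pydantic):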
// class GeneratorItem(BaseModel):
//     temperature: Union[float, None]
//     top_k: Union[int, None]
//     top_p: Union[float, None]
//     min_p: Union[float, None]
//     typical: Union[float, None]
//     token_repetition_penalty_max: Union[float, None]
//     token_repetition_penalty_sustain: Union[int, None]
//     token_repetition_penalty_decay: Union[int, None]
//     beams: Union[int, None]
//     beam_length: Union[int, None]
//     disallowed_tokens: Union[list[int], None]
//     prompt: str
//     max_new_tokens: Union[int, None]
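
// TypeScript mirror of GeneratorItem above; only prompt is required.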
interface LocalGeneratorItem {
    temperature?: number;
    top_k?: number;
    top_p?: number;
    min_p?: number;
    typical?: number;
    token_repetition_penalty_max?: number;
    token_repetition_penalty_sustain?: number;
    token_repetition_penalty_decay?: number;
    beams?: number;
    beam_length?: number;
    disallowed_tokens?: number[];
    prompt: string;
    max_new_tokens?: number;
}
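
// Returns true when the local server answers on port 7239, false otherwise.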
export async function checkServerRunning() {
    try {
        console.log("checking server")
        const res = await fetch("http://localhost:7239/")
        console.log(res)
        return res.ok
    } catch (error) {
        return false
    }
}
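
// Installs the local server if needed, starts it, waits until it responds,
// then asks it to load the model via the /load/ endpoint.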
export async function loadExllamaFull() {
    try {
        await startLocalModelServer()
        if (await checkLocalServerInstalled()) {
            alertWait("Loading exllama")
        }
        else {
            alertWait("Installing & loading exllama; this may take a while the first time")
        }
        // Check if the server is running by polling its status endpoint
        while (true) {
            if (await checkLocalServerInstalled()) {
                await sleep(1000)
                try {
                    const res = await fetch("http://localhost:7239/")
                    if (res.status === 200) {
                        break
                    }
                } catch (error) {}
            }
            await sleep(1000)
        }
        const body: LocalLoaderItem = {
            // Hard-coded model directory; presumably a development placeholder
            dir: "C:\\Users\\blueb\\Downloads\\model",
        }
        alertWait("Loading Local Model")
        const res = await globalFetch("http://localhost:7239/load/", {
            body: body
        })
        alertClear()
    } catch (error) {
        alertError("Error when loading Exllama: " + error)
    }
}
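
// Generates text for the given prompt, loading exllama first if the server
// has not been started in this session. Sampling settings come from the
// ooba section of the user's database settings.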
export async function runLocalModel(prompt: string) {
    const db = get(DataBase)
    if (!serverRunning) {
        await loadExllamaFull()
    }
    const body: LocalGeneratorItem = {
        prompt: prompt,
        temperature: db.temperature,
        top_k: db.ooba.top_k,
        top_p: db.ooba.top_p,
        typical: db.ooba.typical_p,
        max_new_tokens: db.maxResponse
    }
    console.log("generating")
    const gen = await globalFetch("http://localhost:7239/generate/", {
        body: body
    })
    console.log(gen)
}