From 9db4810bbcc60920241a180e1caacc836a3d900d Mon Sep 17 00:00:00 2001
From: kwaroran
Date: Tue, 16 Jan 2024 10:56:23 +0900
Subject: [PATCH] Experimental llamacpp support

---
 src-tauri/src-python/llamacpp.py |  31 ------
 src-tauri/src-python/main.py     |  92 +++++++++++++++--
 src/lib/UI/ModelList.svelte      |   4 +-
 src/ts/process/models/local.ts   | 163 ++++++++++++++++++++++++-------
 src/ts/process/request.ts        |  33 ++++++-
 src/ts/tokenizer.ts              |   7 +-
 6 files changed, 248 insertions(+), 82 deletions(-)
 delete mode 100644 src-tauri/src-python/llamacpp.py

diff --git a/src-tauri/src-python/llamacpp.py b/src-tauri/src-python/llamacpp.py
deleted file mode 100644
index 2c987eb3..00000000
--- a/src-tauri/src-python/llamacpp.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama
\ No newline at end of file
diff --git a/src-tauri/src-python/main.py b/src-tauri/src-python/main.py
index 3e08d7f9..cb49a1ce 100644
--- a/src-tauri/src-python/main.py
+++ b/src-tauri/src-python/main.py
@@ -1,12 +1,17 @@
 from fastapi import FastAPI, Header
 from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
 import uuid
 import os
+import sys
+
+# Write key for authentication
 
 app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
 if not os.path.exists(key_dir):
     f = open(key_dir, 'w')
     f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@
 f = open(key_dir, 'r')
 key = f.read()
 f.close()
-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Authentication endpoint
 
 @app.get("/")
 async def autha():
@@ -27,4 +35,70 @@
 
 @app.get("/auth")
 async def auth():
-    return {"dir": key_dir}
\ No newline at end of file
+    return {"dir": key_dir}
+
+
+# Llamacpp endpoint
+
+class LlamaItem(BaseModel):
+    prompt: str
+    model_path: str
+    temperature: float
+    top_p: float
+    top_k: int
+    max_tokens: int
+    presence_penalty: float
+    frequency_penalty: float
+    repeat_penalty: float
+    n_ctx: int
+    stop: List[str]
+
+app.n_ctx = 2000
+app.last_model_path = ""
+app.llm:Llama = None
+
+
+def stream_chat_llamacpp(item:LlamaItem):
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+    chunks = app.llm.create_completion(
+        prompt = item.prompt,
+        temperature = item.temperature,
+        top_p = item.top_p,
+        top_k = item.top_k,
+        max_tokens = item.max_tokens,
+        presence_penalty = item.presence_penalty,
+        frequency_penalty = item.frequency_penalty,
+        repeat_penalty = item.repeat_penalty,
+        stop=item.stop,
+        stream=True
+    )
+    for chunk in chunks:
+        cont:CompletionChunk = chunk
+        encoded = cont["choices"][0]["text"]
+        print(encoded, end="")
+        yield encoded
+
+@app.post("/llamacpp")
+async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    return StreamingResponse(stream_chat_llamacpp(item))
+
+class LlamaTokenizeItem(BaseModel):
+    prompt: str
+    model_path: str
+    n_ctx: int
+
+@app.post("/llamacpp/tokenize")
+async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+
+    return app.llm.tokenize(item.prompt.encode('utf-8'))
\ No newline at end of file
diff --git a/src/lib/UI/ModelList.svelte b/src/lib/UI/ModelList.svelte
index a06e6955..bc2c40d5 100644
--- a/src/lib/UI/ModelList.svelte
+++ b/src/lib/UI/ModelList.svelte
@@ -145,8 +145,8 @@
 {#if import.meta.env.DEV}
+            changeModel('local_') // TODO: Fix this
+        }}>Local GGUF Model
 {/if}
 {#if showUnrec}
diff --git a/src/ts/process/models/local.ts b/src/ts/process/models/local.ts
index 3b7e3d88..d68a34db 100644
--- a/src/ts/process/models/local.ts
+++ b/src/ts/process/models/local.ts
@@ -2,7 +2,7 @@ import { invoke } from "@tauri-apps/api/tauri";
 import { globalFetch } from "src/ts/storage/globalApi";
 import { sleep } from "src/ts/util";
 import * as path from "@tauri-apps/api/path";
-import { exists } from "@tauri-apps/api/fs";
+import { exists, readTextFile } from "@tauri-apps/api/fs";
 import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
 import { get } from "svelte/store";
 import { DataBase } from "src/ts/storage/database";
@@ -130,7 +130,7 @@ export async function loadExllamaFull(){
 
 }
 
-export async function runLocalModel(prompt:string){
+async function runLocalModelOld(prompt:string){
     const db = get(DataBase)
 
     if(!serverRunning){
@@ -155,48 +155,139 @@ export async function runLocalModel(prompt:string){
     console.log(gen)
 }
 
+let initPython = false
 export async function installPython(){
+    if(initPython){
+        return
+    }
+    initPython = true
     const appDir = await path.appDataDir()
     const completedPath = await path.join(appDir, 'python', 'completed.txt')
     if(await exists(completedPath)){
-        const dependencies = [
-            'pydantic',
-            'scikit-build',
-            'scikit-build-core',
-            'pyproject_metadata',
-            'pathspec',
-            'llama-cpp-python',
-            'uvicorn[standard]',
-            'fastapi'
-        ]
-        for(const dep of dependencies){
-            alertWait("Installing Python Dependencies (" + dep + ")")
-            await invoke('install_py_dependencies', {
-                path: appDir,
-                dependency: dep
-            })
-        }
-
-        const srvPath = await resolveResource('/src-python/')
-        await invoke('run_py_server', {
-            pyPath: appDir,
+        alertWait("Python is already installed, skipping")
+    }
+    else{
+        alertWait("Installing Python")
+        await invoke("install_python", {
+            path: appDir
+        })
+        alertWait("Installing Pip")
Pip") + await invoke("install_pip", { + path: appDir + }) + alertWait("Rewriting requirements") + await invoke('post_py_install', { + path: appDir + }) + + alertClear() + } + const dependencies = [ + 'pydantic', + 'scikit-build', + 'scikit-build-core', + 'pyproject_metadata', + 'pathspec', + 'llama-cpp-python', + 'uvicorn[standard]', + 'fastapi' + ] + for(const dep of dependencies){ + alertWait("Installing Python Dependencies (" + dep + ")") + await invoke('install_py_dependencies', { + path: appDir, + dependency: dep }) - alertMd("Python Server is running at: " + srvPath) - return } - alertWait("Installing Python") - await invoke("install_python", { - path: appDir + await invoke('run_py_server', { + pyPath: appDir, }) - alertWait("Installing Pip") - await invoke("install_pip", { - path: appDir - }) - alertWait("Rewriting requirements") - await invoke('post_py_install', { - path: appDir + await sleep(4000) + alertClear() + return + +} + +export async function getLocalKey(retry = true) { + try { + const ft = await fetch("http://localhost:10026/") + const keyJson = await ft.json() + const keyPath = keyJson.dir + const key = await readTextFile(keyPath) + return key + } catch (error) { + if(!retry){ + throw `Error when getting local key: ${error}` + } + //if is cors error + if( + error.message.includes("NetworkError when attempting to fetch resource.") + || error.message.includes("Failed to fetch") + ){ + await installPython() + return await getLocalKey(false) + } + else{ + throw `Error when getting local key: ${error}` + } + } +} + +export async function runGGUFModel(arg:{ + prompt: string + modelPath: string + temperature: number + top_p: number + top_k: number + maxTokens: number + presencePenalty: number + frequencyPenalty: number + repeatPenalty: number + maxContext: number + stop: string[] +}) { + const key = await getLocalKey() + const b = await fetch("http://localhost:10026/llamacpp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-risu-auth": key + }, + body: JSON.stringify({ + prompt: arg.prompt, + model_path: arg.modelPath, + temperature: arg.temperature, + top_p: arg.top_p, + top_k: arg.top_k, + max_tokens: arg.maxTokens, + presence_penalty: arg.presencePenalty, + frequency_penalty: arg.frequencyPenalty, + repeat_penalty: arg.repeatPenalty, + n_ctx: arg.maxContext, + stop: arg.stop + }) }) - alertClear() + return b.body +} + +export async function tokenizeGGUFModel(prompt:string):Promise { + const key = await getLocalKey() + const db = get(DataBase) + const modelPath = db.aiModel.replace('local_', '') + const b = await fetch("http://localhost:10026/llamacpp/tokenize", { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-risu-auth": key + }, + body: JSON.stringify({ + prompt: prompt, + n_ctx: db.maxContext, + model_path: modelPath + }) + }) + + return await b.json() } \ No newline at end of file diff --git a/src/ts/process/request.ts b/src/ts/process/request.ts index 43907563..e2abe7d5 100644 --- a/src/ts/process/request.ts +++ b/src/ts/process/request.ts @@ -10,7 +10,7 @@ import { createDeep } from "./deepai"; import { hubURL } from "../characterCards"; import { NovelAIBadWordIds, stringlizeNAIChat } from "./models/nai"; import { strongBan, tokenizeNum } from "../tokenizer"; -import { runLocalModel } from "./models/local"; +import { runGGUFModel } from "./models/local"; import { risuChatParser } from "../parser"; import { SignatureV4 } from "@smithy/signature-v4"; import { HttpRequest } from "@smithy/protocol-http"; @@ -1685,7 
             const suggesting = model === "submodel"
             const proompt = stringlizeChatOba(formated, currentChar.name, suggesting, arg.continue)
             const stopStrings = getStopStrings(suggesting)
-            await runLocalModel(proompt)
+            const modelPath = aiModel.replace('local_', '')
+            const res = await runGGUFModel({
+                prompt: proompt,
+                modelPath: modelPath,
+                temperature: temperature,
+                top_p: db.top_p,
+                top_k: db.top_k,
+                maxTokens: maxTokens,
+                presencePenalty: arg.PresensePenalty || (db.PresensePenalty / 100),
+                frequencyPenalty: arg.frequencyPenalty || (db.frequencyPenalty / 100),
+                repeatPenalty: 0,
+                maxContext: db.maxContext,
+                stop: stopStrings,
+            })
+            let decoded = ''
+            const transtream = new TransformStream({
+                async transform(chunk, control) {
+                    const decodedChunk = new TextDecoder().decode(chunk)
+                    decoded += decodedChunk
+                    control.enqueue({
+                        "0": decoded
+                    })
+                }
+            })
+            res.pipeTo(transtream.writable)
+
+            return {
+                type: 'streaming',
+                result: transtream.readable
+            }
         }
         return {
             type: 'fail',
diff --git a/src/ts/tokenizer.ts b/src/ts/tokenizer.ts
index 0fbedb09..bc565ccd 100644
--- a/src/ts/tokenizer.ts
+++ b/src/ts/tokenizer.ts
@@ -5,6 +5,7 @@ import { get } from "svelte/store";
 import type { OpenAIChat } from "./process";
 import { supportsInlayImage } from "./image";
 import { risuChatParser } from "./parser";
+import { tokenizeGGUFModel } from "./process/models/local";
 
 async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     let db = get(DataBase)
@@ -21,12 +22,14 @@
     if(db.aiModel.startsWith('mistral')){
         return await tokenizeWebTokenizers(data, 'mistral')
     }
-    if(db.aiModel.startsWith('local_') ||
-        db.aiModel === 'mancer' ||
+    if(db.aiModel === 'mancer' ||
         db.aiModel === 'textgen_webui' ||
         (db.aiModel === 'reverse_proxy' && db.reverseProxyOobaMode)){
         return await tokenizeWebTokenizers(data, 'llama')
     }
+    if(db.aiModel.startsWith('local_')){
+        return await tokenizeGGUFModel(data)
+    }
     if(db.aiModel === 'ooba'){
        if(db.reverseProxyOobaArgs.tokenizer === 'mixtral' || db.reverseProxyOobaArgs.tokenizer === 'mistral'){
             return await tokenizeWebTokenizers(data, 'mistral')
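
Below is a hypothetical, standalone TypeScript sketch (not part of the patch) of how a client can talk to the local llama.cpp sidecar this diff introduces: ask GET / for the location of the key file, read the key, then stream a completion from POST /llamacpp with the x-risu-auth header. The port (10026), the header name, the endpoint paths, and the LlamaItem field names are taken from the diff above; reading the key with node:fs instead of Tauri's readTextFile, the model path, and the sampling values are illustrative assumptions only.

```ts
// Hypothetical standalone client for the local llama.cpp server added in this patch.
// Assumes the FastAPI sidecar is already running on port 10026 and that the key file
// it writes next to the Python executable is readable from this process (RisuAI itself
// reads it with Tauri's readTextFile; node:fs stands in here).
import { readFile } from "node:fs/promises";

const BASE = "http://localhost:10026";

async function streamCompletion(prompt: string, modelPath: string): Promise<void> {
    // GET / reports where main.py wrote its key file (the client reads keyJson.dir).
    const keyInfo = await (await fetch(`${BASE}/`)).json() as { dir: string };
    const key = (await readFile(keyInfo.dir, "utf-8")).trim();

    // POST the prompt; field names mirror the LlamaItem model in main.py.
    const res = await fetch(`${BASE}/llamacpp`, {
        method: "POST",
        headers: { "Content-Type": "application/json", "x-risu-auth": key },
        body: JSON.stringify({
            prompt: prompt,
            model_path: modelPath,
            temperature: 0.7,          // sampling values below are placeholders
            top_p: 0.95,
            top_k: 40,
            max_tokens: 256,
            presence_penalty: 0,
            frequency_penalty: 0,
            repeat_penalty: 1.1,
            n_ctx: 2048,
            stop: ["\nUser:"]
        })
    });
    if (!res.ok || !res.body) {
        throw new Error(`llamacpp request failed: ${res.status}`);
    }

    // The endpoint streams plain-text chunks; decode and print them as they arrive,
    // much like request.ts feeds them through a TransformStream.
    const reader = res.body.getReader();
    const decoder = new TextDecoder();
    for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        process.stdout.write(decoder.decode(value, { stream: true }));
    }
}

// Example usage with a placeholder GGUF model path.
streamCompletion("### Instruction:\nSay hi.\n\n### Response:\n", "/path/to/model.gguf")
    .catch(console.error);
```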