diff --git a/src-tauri/src-python/llamacpp.py b/src-tauri/src-python/llamacpp.py
deleted file mode 100644
index 2c987eb3..00000000
--- a/src-tauri/src-python/llamacpp.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama
\ No newline at end of file
diff --git a/src-tauri/src-python/main.py b/src-tauri/src-python/main.py
index 3e08d7f9..cb49a1ce 100644
--- a/src-tauri/src-python/main.py
+++ b/src-tauri/src-python/main.py
@@ -1,12 +1,17 @@
from fastapi import FastAPI, Header
from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
import uuid
import os
+import sys
+
+# Generate an authentication key on first run; clients read it from disk and send it back in the x-risu-auth header
app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
if not os.path.exists(key_dir):
    f = open(key_dir, 'w')
    f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
key = f.read()
f.close()
-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
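+# Allow cross-origin requests from the app's webview; endpoints are protected by the x-risu-auth key instead.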
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Authentication endpoints
@app.get("/")
async def autha():
@@ -27,4 +35,70 @@ async def autha():
@app.get("/auth")
async def auth():
- return {"dir": key_dir}
\ No newline at end of file
+ return {"dir": key_dir}
+
+
+# llama.cpp endpoints
+
+class LlamaItem(BaseModel):
+    prompt: str
+    model_path: str
+    temperature: float
+    top_p: float
+    top_k: int
+    max_tokens: int
+    presence_penalty: float
+    frequency_penalty: float
+    repeat_penalty: float
+    n_ctx: int
+    stop: List[str]
+
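+# Cached llama.cpp state, kept on the app object so the GGUF model is only
+# reloaded when the model path or context size changes.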
+app.n_ctx = 2000
+app.last_model_path = ""
+app.llm: Llama = None
+
+
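+# Streams completion text from the loaded GGUF model, reloading it lazily when needed.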
+def stream_chat_llamacpp(item:LlamaItem):
+    # (Re)load the model when the path or requested context size changes
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+    chunks = app.llm.create_completion(
+        prompt = item.prompt,
+        temperature = item.temperature,
+        top_p = item.top_p,
+        top_k = item.top_k,
+        max_tokens = item.max_tokens,
+        presence_penalty = item.presence_penalty,
+        frequency_penalty = item.frequency_penalty,
+        repeat_penalty = item.repeat_penalty,
+        stop = item.stop,
+        stream = True
+    )
+    for chunk in chunks:
+        cont: CompletionChunk = chunk
+        text = cont["choices"][0]["text"]
+        print(text, end="")
+        yield text
+
+@app.post("/llamacpp")
+async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    return StreamingResponse(stream_chat_llamacpp(item))
+
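+# Tokenizer endpoint; lets the frontend count tokens with the loaded model's own vocabulary.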
+class LlamaTokenizeItem(BaseModel):
+    prompt: str
+    model_path: str
+    n_ctx: int
+
+@app.post("/llamacpp/tokenize")
+async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
+ if key != x_risu_auth:
+ return {"error": "Invalid key"}
+ if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+ app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
+ app.last_model_path = item.model_path
+ app.n_ctx = item.n_ctx
+
+ return app.llm.tokenize(item.prompt.encode('utf-8'))
\ No newline at end of file
diff --git a/src/lib/UI/ModelList.svelte b/src/lib/UI/ModelList.svelte
index a06e6955..bc2c40d5 100644
--- a/src/lib/UI/ModelList.svelte
+++ b/src/lib/UI/ModelList.svelte
@@ -145,8 +145,8 @@
{#if import.meta.env.DEV}
+ changeModel('local_') // TODO: Fix this
+ }}>Local GGUF Model
{/if}
{#if showUnrec}
diff --git a/src/ts/process/models/local.ts b/src/ts/process/models/local.ts
index 3b7e3d88..d68a34db 100644
--- a/src/ts/process/models/local.ts
+++ b/src/ts/process/models/local.ts
@@ -2,7 +2,7 @@ import { invoke } from "@tauri-apps/api/tauri";
import { globalFetch } from "src/ts/storage/globalApi";
import { sleep } from "src/ts/util";
import * as path from "@tauri-apps/api/path";
-import { exists } from "@tauri-apps/api/fs";
+import { exists, readTextFile } from "@tauri-apps/api/fs";
import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
import { get } from "svelte/store";
import { DataBase } from "src/ts/storage/database";
@@ -130,7 +130,7 @@ export async function loadExllamaFull(){
}
-export async function runLocalModel(prompt:string){
+async function runLocalModelOld(prompt:string){
    const db = get(DataBase)
    if(!serverRunning){
@@ -155,48 +155,139 @@ export async function runLocalModel(prompt:string){
    console.log(gen)
}
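+// Run-once guard so repeated calls don't re-run the Python bootstrap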
+let initPython = false
export async function installPython(){
+    if(initPython){
+        return
+    }
+    initPython = true
    const appDir = await path.appDataDir()
    const completedPath = await path.join(appDir, 'python', 'completed.txt')
    if(await exists(completedPath)){
-        const dependencies = [
-            'pydantic',
-            'scikit-build',
-            'scikit-build-core',
-            'pyproject_metadata',
-            'pathspec',
-            'llama-cpp-python',
-            'uvicorn[standard]',
-            'fastapi'
-        ]
-        for(const dep of dependencies){
-            alertWait("Installing Python Dependencies (" + dep + ")")
-            await invoke('install_py_dependencies', {
-                path: appDir,
-                dependency: dep
-            })
-        }
-
-        const srvPath = await resolveResource('/src-python/')
-        await invoke('run_py_server', {
-            pyPath: appDir,
+        alertWait("Python is already installed, skipping")
+    }
+    else{
+        alertWait("Installing Python")
+        await invoke("install_python", {
+            path: appDir
+        })
+        alertWait("Installing Pip")
+        await invoke("install_pip", {
+            path: appDir
+        })
+        alertWait("Rewriting requirements")
+        await invoke('post_py_install', {
+            path: appDir
+        })
+
+        alertClear()
+    }
+    const dependencies = [
+        'pydantic',
+        'scikit-build',
+        'scikit-build-core',
+        'pyproject_metadata',
+        'pathspec',
+        'llama-cpp-python',
+        'uvicorn[standard]',
+        'fastapi'
+    ]
+    for(const dep of dependencies){
+        alertWait("Installing Python Dependencies (" + dep + ")")
+        await invoke('install_py_dependencies', {
+            path: appDir,
+            dependency: dep
        })
-        alertMd("Python Server is running at: " + srvPath)
-        return
    }
-    alertWait("Installing Python")
-    await invoke("install_python", {
-        path: appDir
+    await invoke('run_py_server', {
+        pyPath: appDir,
    })
-    alertWait("Installing Pip")
-    await invoke("install_pip", {
-        path: appDir
-    })
-    alertWait("Rewriting requirements")
-    await invoke('post_py_install', {
-        path: appDir
+    // give the local server a moment to finish starting up
+    await sleep(4000)
+    alertClear()
+    return
+
+}
+
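+// Asks the local server for the location of its key file, then reads the key
+// itself through the Tauri fs API; if the server isn't reachable yet, it is
+// bootstrapped once and the lookup is retried.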
+export async function getLocalKey(retry = true) {
+    try {
+        const ft = await fetch("http://localhost:10026/")
+        const keyJson = await ft.json()
+        const keyPath = keyJson.dir
+        const key = await readTextFile(keyPath)
+        return key
+    } catch (error) {
+        if(!retry){
+            throw `Error when getting local key: ${error}`
+        }
+        // If the fetch failed at the network level, the local server is
+        // probably not running yet, so bootstrap it once and retry.
+        if(
+            error.message.includes("NetworkError when attempting to fetch resource.")
+            || error.message.includes("Failed to fetch")
+        ){
+            await installPython()
+            return await getLocalKey(false)
+        }
+        else{
+            throw `Error when getting local key: ${error}`
+        }
+    }
+}
+
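+// Sends a completion request to the local llama.cpp server and returns the
+// raw response body as a byte stream for the caller to decode.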
+export async function runGGUFModel(arg:{
+    prompt: string
+    modelPath: string
+    temperature: number
+    top_p: number
+    top_k: number
+    maxTokens: number
+    presencePenalty: number
+    frequencyPenalty: number
+    repeatPenalty: number
+    maxContext: number
+    stop: string[]
+}) {
+    const key = await getLocalKey()
+    const b = await fetch("http://localhost:10026/llamacpp", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: arg.prompt,
+            model_path: arg.modelPath,
+            temperature: arg.temperature,
+            top_p: arg.top_p,
+            top_k: arg.top_k,
+            max_tokens: arg.maxTokens,
+            presence_penalty: arg.presencePenalty,
+            frequency_penalty: arg.frequencyPenalty,
+            repeat_penalty: arg.repeatPenalty,
+            n_ctx: arg.maxContext,
+            stop: arg.stop
+        })
    })
-        alertClear()
+    return b.body
+}
+
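+// Tokenizes a prompt using the currently selected local GGUF model.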
+export async function tokenizeGGUFModel(prompt:string):Promise<number[]> {
+    const key = await getLocalKey()
+    const db = get(DataBase)
+    const modelPath = db.aiModel.replace('local_', '')
+    const b = await fetch("http://localhost:10026/llamacpp/tokenize", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: prompt,
+            n_ctx: db.maxContext,
+            model_path: modelPath
+        })
+    })
+
+    return await b.json()
}
\ No newline at end of file
diff --git a/src/ts/process/request.ts b/src/ts/process/request.ts
index 43907563..e2abe7d5 100644
--- a/src/ts/process/request.ts
+++ b/src/ts/process/request.ts
@@ -10,7 +10,7 @@ import { createDeep } from "./deepai";
import { hubURL } from "../characterCards";
import { NovelAIBadWordIds, stringlizeNAIChat } from "./models/nai";
import { strongBan, tokenizeNum } from "../tokenizer";
-import { runLocalModel } from "./models/local";
+import { runGGUFModel } from "./models/local";
import { risuChatParser } from "../parser";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
@@ -1685,7 +1685,36 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
        const suggesting = model === "submodel"
        const proompt = stringlizeChatOba(formated, currentChar.name, suggesting, arg.continue)
        const stopStrings = getStopStrings(suggesting)
-        await runLocalModel(proompt)
+        const modelPath = aiModel.replace('local_', '')
+        const res = await runGGUFModel({
+            prompt: proompt,
+            modelPath: modelPath,
+            temperature: temperature,
+            top_p: db.top_p,
+            top_k: db.top_k,
+            maxTokens: maxTokens,
+            presencePenalty: arg.PresensePenalty || (db.PresensePenalty / 100),
+            frequencyPenalty: arg.frequencyPenalty || (db.frequencyPenalty / 100),
+            repeatPenalty: 0,
+            maxContext: db.maxContext,
+            stop: stopStrings,
+        })
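+        // Decode the byte stream into text and re-wrap it as the cumulative
+        // {"0": text} chunks the streaming chat UI consumes.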
+        let decoded = ''
+        const transtream = new TransformStream({
+            async transform(chunk, control) {
+                const decodedChunk = new TextDecoder().decode(chunk)
+                decoded += decodedChunk
+                control.enqueue({
+                    "0": decoded
+                })
+            }
+        })
+        res.pipeTo(transtream.writable)
+
+        return {
+            type: 'streaming',
+            result: transtream.readable
+        }
    }
    return {
        type: 'fail',
diff --git a/src/ts/tokenizer.ts b/src/ts/tokenizer.ts
index 0fbedb09..bc565ccd 100644
--- a/src/ts/tokenizer.ts
+++ b/src/ts/tokenizer.ts
@@ -5,6 +5,7 @@ import { get } from "svelte/store";
import type { OpenAIChat } from "./process";
import { supportsInlayImage } from "./image";
import { risuChatParser } from "./parser";
+import { tokenizeGGUFModel } from "./process/models/local";
async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
    let db = get(DataBase)
@@ -21,12 +22,14 @@ async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
    if(db.aiModel.startsWith('mistral')){
        return await tokenizeWebTokenizers(data, 'mistral')
    }
-    if(db.aiModel.startsWith('local_') ||
-        db.aiModel === 'mancer' ||
+    if(db.aiModel === 'mancer' ||
        db.aiModel === 'textgen_webui' ||
        (db.aiModel === 'reverse_proxy' && db.reverseProxyOobaMode)){
        return await tokenizeWebTokenizers(data, 'llama')
    }
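+    // Local GGUF models tokenize through the llama.cpp server, so counts match the loaded model exactly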
+    if(db.aiModel.startsWith('local_')){
+        return await tokenizeGGUFModel(data)
+    }
    if(db.aiModel === 'ooba'){
        if(db.reverseProxyOobaArgs.tokenizer === 'mixtral' || db.reverseProxyOobaArgs.tokenizer === 'mistral'){
            return await tokenizeWebTokenizers(data, 'mistral')