Experimental llamacpp support

kwaroran
2024-01-16 10:56:23 +09:00
parent 91735d0512
commit 9db4810bbc
6 changed files with 248 additions and 82 deletions

View File

@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama

View File

@@ -1,12 +1,17 @@
 from fastapi import FastAPI, Header
 from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
 import uuid
 import os
+import sys
+
+# Write key for authentication
 app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
 if not os.path.exists(key_dir):
     f = open(key_dir, 'w')
     f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
 key = f.read()
 f.close()
-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Authentication endpoint
 @app.get("/")
 async def autha():
@@ -28,3 +36,69 @@ async def autha():
@app.get("/auth") @app.get("/auth")
async def auth(): async def auth():
return {"dir": key_dir} return {"dir": key_dir}
# Llamacpp endpoint
class LlamaItem(BaseModel):
prompt: str
model_path: str
temperature: float
top_p: float
top_k: int
max_tokens: int
presence_penalty: float
frequency_penalty: float
repeat_penalty: float
n_ctx: int
stop: List[str]
app.n_ctx = 2000
app.last_model_path = ""
app.llm:Llama = None
def stream_chat_llamacpp(item:LlamaItem):
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
app.last_model_path = item.model_path
app.n_ctx = item.n_ctx
chunks = app.llm.create_completion(
prompt = item.prompt,
temperature = item.temperature,
top_p = item.top_p,
top_k = item.top_k,
max_tokens = item.max_tokens,
presence_penalty = item.presence_penalty,
frequency_penalty = item.frequency_penalty,
repeat_penalty = item.repeat_penalty,
stop=item.stop,
stream=True
)
for chunk in chunks:
cont:CompletionChunk = chunk
encoded = cont["choices"][0]["text"]
print(encoded, end="")
yield encoded
@app.post("/llamacpp")
async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
if key != x_risu_auth:
return {"error": "Invalid key"}
return StreamingResponse(stream_chat_llamacpp(item))
class LlamaTokenizeItem(BaseModel):
prompt: str
model_path: str
n_ctx: int
@app.post("/llamacpp/tokenize")
async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
if key != x_risu_auth:
return {"error": "Invalid key"}
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
app.last_model_path = item.model_path
app.n_ctx = item.n_ctx
return app.llm.tokenize(item.prompt.encode('utf-8'))
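
The flow the frontend uses against this server (and a quick way to smoke-test it by hand) is: GET / to learn where key.txt lives, read the key from that file, then POST /llamacpp with the key in the x-risu-auth header and read the streamed text. Below is a minimal client sketch of that flow, assuming the server is already running on localhost:10026 (the port the frontend code later in this commit uses) and that the `requests` package is installed; the prompt, stop string, and model path are placeholders, and the sampling values mirror the defaults from the removed llamacpp.py.

import requests

BASE = "http://localhost:10026"

# GET / returns {"dir": <path to key.txt>}; its contents authenticate later calls
key_path = requests.get(BASE + "/").json()["dir"]
with open(key_path) as f:
    key = f.read()

# Stream a completion from the /llamacpp endpoint
resp = requests.post(
    BASE + "/llamacpp",
    headers={"x-risu-auth": key},
    json={
        "prompt": "Hello, ",                   # placeholder prompt
        "model_path": "/path/to/model.gguf",   # placeholder path
        "temperature": 0.2,
        "top_p": 0.95,
        "top_k": 40,
        "max_tokens": 256,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "repeat_penalty": 1.1,
        "n_ctx": 2000,
        "stop": ["\n\n"],                      # placeholder stop string
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8", errors="ignore"), end="")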

View File

@@ -145,8 +145,8 @@
 <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('reverse_proxy')}}>Reverse Proxy</button>
 {#if import.meta.env.DEV}
     <button class="hover:bg-selected px-6 py-2 text-lg" on:click={async () => {
-        changeModel('local_gptq')
-    }}>Local Model GPTQ <Help key="experimental"/> </button>
+        changeModel('local_') // TODO: Fix this
+    }}>Local GGUF Model <Help key="experimental"/> </button>
 {/if}
 <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('ooba')}}>Oobabooga</button>
 {#if showUnrec}

View File

@@ -2,7 +2,7 @@ import { invoke } from "@tauri-apps/api/tauri";
 import { globalFetch } from "src/ts/storage/globalApi";
 import { sleep } from "src/ts/util";
 import * as path from "@tauri-apps/api/path";
-import { exists } from "@tauri-apps/api/fs";
+import { exists, readTextFile } from "@tauri-apps/api/fs";
 import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
 import { get } from "svelte/store";
 import { DataBase } from "src/ts/storage/database";
@@ -130,7 +130,7 @@ export async function loadExllamaFull(){
 }
-export async function runLocalModel(prompt:string){
+async function runLocalModelOld(prompt:string){
     const db = get(DataBase)
     if(!serverRunning){
@@ -155,10 +155,33 @@ export async function runLocalModel(prompt:string){
     console.log(gen)
 }
+
+let initPython = false
 export async function installPython(){
+    if(initPython){
+        return
+    }
+    initPython = true
     const appDir = await path.appDataDir()
     const completedPath = await path.join(appDir, 'python', 'completed.txt')
     if(await exists(completedPath)){
+        alertWait("Python is already installed, skipping")
+    }
+    else{
+        alertWait("Installing Python")
+        await invoke("install_python", {
+            path: appDir
+        })
+        alertWait("Installing Pip")
+        await invoke("install_pip", {
+            path: appDir
+        })
+        alertWait("Rewriting requirements")
+        await invoke('post_py_install', {
+            path: appDir
+        })
+        alertClear()
+    }
     const dependencies = [
         'pydantic',
         'scikit-build',
@@ -177,26 +200,94 @@ export async function installPython(){
         })
     }
-    const srvPath = await resolveResource('/src-python/')
     await invoke('run_py_server', {
         pyPath: appDir,
     })
-    alertMd("Python Server is running at: " + srvPath)
-    return
-    }
-    alertWait("Installing Python")
-    await invoke("install_python", {
-        path: appDir
-    })
-    alertWait("Installing Pip")
-    await invoke("install_pip", {
-        path: appDir
-    })
-    alertWait("Rewriting requirements")
-    await invoke('post_py_install', {
-        path: appDir
-    })
+    await sleep(4000)
     alertClear()
+    return
+}
+
+export async function getLocalKey(retry = true) {
+    try {
+        const ft = await fetch("http://localhost:10026/")
+        const keyJson = await ft.json()
+        const keyPath = keyJson.dir
+        const key = await readTextFile(keyPath)
+        return key
+    } catch (error) {
+        if(!retry){
+            throw `Error when getting local key: ${error}`
+        }
+        //if is cors error
+        if(
+            error.message.includes("NetworkError when attempting to fetch resource.")
+            || error.message.includes("Failed to fetch")
+        ){
+            await installPython()
+            return await getLocalKey(false)
+        }
+        else{
+            throw `Error when getting local key: ${error}`
+        }
+    }
+}
+
+export async function runGGUFModel(arg:{
+    prompt: string
+    modelPath: string
+    temperature: number
+    top_p: number
+    top_k: number
+    maxTokens: number
+    presencePenalty: number
+    frequencyPenalty: number
+    repeatPenalty: number
+    maxContext: number
+    stop: string[]
+}) {
+    const key = await getLocalKey()
+    const b = await fetch("http://localhost:10026/llamacpp", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: arg.prompt,
+            model_path: arg.modelPath,
+            temperature: arg.temperature,
+            top_p: arg.top_p,
+            top_k: arg.top_k,
+            max_tokens: arg.maxTokens,
+            presence_penalty: arg.presencePenalty,
+            frequency_penalty: arg.frequencyPenalty,
+            repeat_penalty: arg.repeatPenalty,
+            n_ctx: arg.maxContext,
+            stop: arg.stop
+        })
+    })
+    return b.body
+}
+
+export async function tokenizeGGUFModel(prompt:string):Promise<number[]> {
+    const key = await getLocalKey()
+    const db = get(DataBase)
+    const modelPath = db.aiModel.replace('local_', '')
+    const b = await fetch("http://localhost:10026/llamacpp/tokenize", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: prompt,
+            n_ctx: db.maxContext,
+            model_path: modelPath
+        })
+    })
+    return await b.json()
 }
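
The tokenizeGGUFModel helper above wraps the /llamacpp/tokenize endpoint, which returns the token ids for a prompt so the frontend can budget context. A minimal sketch of the same call outside the app, under the same assumptions as the earlier example (server on localhost:10026, `requests` installed, placeholder prompt and model path):

import requests

BASE = "http://localhost:10026"
key = open(requests.get(BASE + "/").json()["dir"]).read()

tokens = requests.post(
    BASE + "/llamacpp/tokenize",
    headers={"x-risu-auth": key},
    json={
        "prompt": "Hello world",               # placeholder prompt
        "model_path": "/path/to/model.gguf",   # placeholder path
        "n_ctx": 2000,
    },
).json()
print(len(tokens), "tokens")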

View File

@@ -10,7 +10,7 @@ import { createDeep } from "./deepai";
 import { hubURL } from "../characterCards";
 import { NovelAIBadWordIds, stringlizeNAIChat } from "./models/nai";
 import { strongBan, tokenizeNum } from "../tokenizer";
-import { runLocalModel } from "./models/local";
+import { runGGUFModel } from "./models/local";
 import { risuChatParser } from "../parser";
 import { SignatureV4 } from "@smithy/signature-v4";
 import { HttpRequest } from "@smithy/protocol-http";
@@ -1685,7 +1685,36 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
         const suggesting = model === "submodel"
         const proompt = stringlizeChatOba(formated, currentChar.name, suggesting, arg.continue)
         const stopStrings = getStopStrings(suggesting)
-        await runLocalModel(proompt)
+        const modelPath = aiModel.replace('local_', '')
+        const res = await runGGUFModel({
+            prompt: proompt,
+            modelPath: modelPath,
+            temperature: temperature,
+            top_p: db.top_p,
+            top_k: db.top_k,
+            maxTokens: maxTokens,
+            presencePenalty: arg.PresensePenalty || (db.PresensePenalty / 100),
+            frequencyPenalty: arg.frequencyPenalty || (db.frequencyPenalty / 100),
+            repeatPenalty: 0,
+            maxContext: db.maxContext,
+            stop: stopStrings,
+        })
+        let decoded = ''
+        const transtream = new TransformStream<Uint8Array, StreamResponseChunk>({
+            async transform(chunk, control) {
+                const decodedChunk = new TextDecoder().decode(chunk)
+                decoded += decodedChunk
+                control.enqueue({
+                    "0": decoded
+                })
+            }
+        })
+        res.pipeTo(transtream.writable)
+        return {
+            type: 'streaming',
+            result: transtream.readable
+        }
     }
     return {
         type: 'fail',

View File

@@ -5,6 +5,7 @@ import { get } from "svelte/store";
 import type { OpenAIChat } from "./process";
 import { supportsInlayImage } from "./image";
 import { risuChatParser } from "./parser";
+import { tokenizeGGUFModel } from "./process/models/local";
 async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     let db = get(DataBase)
@@ -21,12 +22,14 @@ async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     if(db.aiModel.startsWith('mistral')){
         return await tokenizeWebTokenizers(data, 'mistral')
     }
-    if(db.aiModel.startsWith('local_') ||
-        db.aiModel === 'mancer' ||
+    if(db.aiModel === 'mancer' ||
         db.aiModel === 'textgen_webui' ||
         (db.aiModel === 'reverse_proxy' && db.reverseProxyOobaMode)){
         return await tokenizeWebTokenizers(data, 'llama')
     }
+    if(db.aiModel.startsWith('local_')){
+        return await tokenizeGGUFModel(data)
+    }
     if(db.aiModel === 'ooba'){
         if(db.reverseProxyOobaArgs.tokenizer === 'mixtral' || db.reverseProxyOobaArgs.tokenizer === 'mistral'){
             return await tokenizeWebTokenizers(data, 'mistral')