Experimental llamacpp support

kwaroran committed 2024-01-16 10:56:23 +09:00
parent 91735d0512
commit 9db4810bbc
6 changed files with 248 additions and 82 deletions

View File

@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama

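For context on why the rewrite in the next file moves this state onto the FastAPI app object: the deleted module assigned llm, last_model_path, and n_ctx inside stream_chat_llamacpp without global declarations, which makes them function-locals in Python, so the very first comparison reads an unbound local. (The trailing commas on the field defaults, e.g. temperature: float = 0.2,, also made each default a one-element tuple.) A minimal sketch of the scoping pitfall — plain CPython behavior, not RisuAI code:

    llm = None

    def load_model(path):
        if llm is None:       # raises UnboundLocalError: the assignment on
            llm = object()    # the next line makes `llm` local to load_model
        return llm

    load_model("model.gguf")  # hypothetical path; the call fails as described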
View File

@@ -1,12 +1,17 @@
 from fastapi import FastAPI, Header
 from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
 import uuid
 import os
+import sys

 # Write key for authentication
 app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
 if not os.path.exists(key_dir):
     f = open(key_dir, 'w')
     f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
 key = f.read()
 f.close()

-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

 # Authentication endpoint
 @app.get("/")
 async def autha():
@@ -27,4 +35,70 @@ async def autha():
@app.get("/auth")
async def auth():
return {"dir": key_dir}
return {"dir": key_dir}
# Llamacpp endpoint
class LlamaItem(BaseModel):
prompt: str
model_path: str
temperature: float
top_p: float
top_k: int
max_tokens: int
presence_penalty: float
frequency_penalty: float
repeat_penalty: float
n_ctx: int
stop: List[str]
app.n_ctx = 2000
app.last_model_path = ""
app.llm:Llama = None
def stream_chat_llamacpp(item:LlamaItem):
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
app.last_model_path = item.model_path
app.n_ctx = item.n_ctx
chunks = app.llm.create_completion(
prompt = item.prompt,
temperature = item.temperature,
top_p = item.top_p,
top_k = item.top_k,
max_tokens = item.max_tokens,
presence_penalty = item.presence_penalty,
frequency_penalty = item.frequency_penalty,
repeat_penalty = item.repeat_penalty,
stop=item.stop,
stream=True
)
for chunk in chunks:
cont:CompletionChunk = chunk
encoded = cont["choices"][0]["text"]
print(encoded, end="")
yield encoded
@app.post("/llamacpp")
async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
if key != x_risu_auth:
return {"error": "Invalid key"}
return StreamingResponse(stream_chat_llamacpp(item))
class LlamaTokenizeItem(BaseModel):
prompt: str
model_path: str
n_ctx: int
@app.post("/llamacpp/tokenize")
async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
if key != x_risu_auth:
return {"error": "Invalid key"}
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
app.last_model_path = item.model_path
app.n_ctx = item.n_ctx
return app.llm.tokenize(item.prompt.encode('utf-8'))

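The new endpoints are easy to exercise by hand. The sketch below is a usage example, not part of the commit: it assumes the server is listening on localhost:10026 (the port the TypeScript client later in this commit uses), that key.txt sits where main.py above writes it (adjust key_path if your interpreter differs from the server's), and a placeholder GGUF path:

    import os
    import sys

    import requests

    key_path = os.path.join(os.path.dirname(sys.executable), "key.txt")
    with open(key_path) as f:
        key = f.read()

    body = {
        "prompt": "### Instruction:\nHello\n### Response:\n",
        "model_path": "/path/to/model.gguf",  # placeholder
        "temperature": 0.7, "top_p": 0.95, "top_k": 40,
        "max_tokens": 256, "presence_penalty": 0.0, "frequency_penalty": 0.0,
        "repeat_penalty": 1.1, "n_ctx": 2000, "stop": ["### Instruction:"],
    }

    # /llamacpp streams raw UTF-8 text chunks as they are generated
    with requests.post("http://localhost:10026/llamacpp", json=body,
                       headers={"x-risu-auth": key}, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None):
            print(chunk.decode("utf-8"), end="", flush=True)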
View File

@@ -145,8 +145,8 @@
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('reverse_proxy')}}>Reverse Proxy</button>
{#if import.meta.env.DEV}
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={async () => {
changeModel('local_gptq')
}}>Local Model GPTQ <Help key="experimental"/> </button>
changeModel('local_') // TODO: Fix this
}}>Local GGUF Model <Help key="experimental"/> </button>
{/if}
<button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('ooba')}}>Oobabooga</button>
{#if showUnrec}

View File

@@ -2,7 +2,7 @@ import { invoke } from "@tauri-apps/api/tauri";
 import { globalFetch } from "src/ts/storage/globalApi";
 import { sleep } from "src/ts/util";
 import * as path from "@tauri-apps/api/path";
-import { exists } from "@tauri-apps/api/fs";
+import { exists, readTextFile } from "@tauri-apps/api/fs";
 import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
 import { get } from "svelte/store";
 import { DataBase } from "src/ts/storage/database";
@@ -130,7 +130,7 @@ export async function loadExllamaFull(){
 }

-export async function runLocalModel(prompt:string){
+async function runLocalModelOld(prompt:string){
     const db = get(DataBase)
     if(!serverRunning){
@@ -155,48 +155,139 @@ export async function runLocalModel(prompt:string){
     console.log(gen)
 }

 let initPython = false

 export async function installPython(){
     if(initPython){
         return
     }
     initPython = true
     const appDir = await path.appDataDir()
     const completedPath = await path.join(appDir, 'python', 'completed.txt')
     if(await exists(completedPath)){
-        const dependencies = [
-            'pydantic',
-            'scikit-build',
-            'scikit-build-core',
-            'pyproject_metadata',
-            'pathspec',
-            'llama-cpp-python',
-            'uvicorn[standard]',
-            'fastapi'
-        ]
-        for(const dep of dependencies){
-            alertWait("Installing Python Dependencies (" + dep + ")")
-            await invoke('install_py_dependencies', {
-                path: appDir,
-                dependency: dep
-            })
-        }
-        const srvPath = await resolveResource('/src-python/')
-        await invoke('run_py_server', {
-            pyPath: appDir,
+        alertWait("Python is already installed, skipping")
+    }
+    else{
+        alertWait("Installing Python")
+        await invoke("install_python", {
+            path: appDir
+        })
+        alertWait("Installing Pip")
+        await invoke("install_pip", {
+            path: appDir
+        })
+        alertWait("Rewriting requirements")
+        await invoke('post_py_install', {
+            path: appDir
+        })
+        alertClear()
+    }
+    const dependencies = [
+        'pydantic',
+        'scikit-build',
+        'scikit-build-core',
+        'pyproject_metadata',
+        'pathspec',
+        'llama-cpp-python',
+        'uvicorn[standard]',
+        'fastapi'
+    ]
+    for(const dep of dependencies){
+        alertWait("Installing Python Dependencies (" + dep + ")")
+        await invoke('install_py_dependencies', {
+            path: appDir,
+            dependency: dep
         })
-        alertMd("Python Server is running at: " + srvPath)
-        return
     }
-    alertWait("Installing Python")
-    await invoke("install_python", {
-        path: appDir
+    await invoke('run_py_server', {
+        pyPath: appDir,
     })
-    alertWait("Installing Pip")
-    await invoke("install_pip", {
-        path: appDir
-    })
-    alertWait("Rewriting requirements")
-    await invoke('post_py_install', {
-        path: appDir
-    })
+    await sleep(4000)
     alertClear()
+    return
 }
+
+export async function getLocalKey(retry = true) {
+    try {
+        const ft = await fetch("http://localhost:10026/")
+        const keyJson = await ft.json()
+        const keyPath = keyJson.dir
+        const key = await readTextFile(keyPath)
+        return key
+    } catch (error) {
+        if(!retry){
+            throw `Error when getting local key: ${error}`
+        }
+        //if is cors error
+        if(
+            error.message.includes("NetworkError when attempting to fetch resource.")
+            || error.message.includes("Failed to fetch")
+        ){
+            await installPython()
+            return await getLocalKey(false)
+        }
+        else{
+            throw `Error when getting local key: ${error}`
+        }
+    }
+}
+
+export async function runGGUFModel(arg:{
+    prompt: string
+    modelPath: string
+    temperature: number
+    top_p: number
+    top_k: number
+    maxTokens: number
+    presencePenalty: number
+    frequencyPenalty: number
+    repeatPenalty: number
+    maxContext: number
+    stop: string[]
+}) {
+    const key = await getLocalKey()
+    const b = await fetch("http://localhost:10026/llamacpp", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: arg.prompt,
+            model_path: arg.modelPath,
+            temperature: arg.temperature,
+            top_p: arg.top_p,
+            top_k: arg.top_k,
+            max_tokens: arg.maxTokens,
+            presence_penalty: arg.presencePenalty,
+            frequency_penalty: arg.frequencyPenalty,
+            repeat_penalty: arg.repeatPenalty,
+            n_ctx: arg.maxContext,
+            stop: arg.stop
+        })
+    })
+    alertClear()
+    return b.body
+}
+
+export async function tokenizeGGUFModel(prompt:string):Promise<number[]> {
+    const key = await getLocalKey()
+    const db = get(DataBase)
+    const modelPath = db.aiModel.replace('local_', '')
+    const b = await fetch("http://localhost:10026/llamacpp/tokenize", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: prompt,
+            n_ctx: db.maxContext,
+            model_path: modelPath
+        })
+    })
+    return await b.json()
+}

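The key handshake above deserves a note: with CORS fully open, the server cannot rely on origin checks, so authentication works by proving local filesystem access. GET / discloses only the path of key.txt (the TypeScript reads keyJson.dir), and the caller must read the key from disk itself, which a remote page cannot do. A rough Python rendering of getLocalKey under those assumptions, minus the install-and-retry branch (illustration only, not part of the commit):

    import requests

    def get_local_key(base="http://localhost:10026"):
        # the server reveals only where the key lives...
        key_path = requests.get(base + "/").json()["dir"]
        # ...actually reading it requires local disk access
        with open(key_path) as f:
            return f.read()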
View File

@@ -10,7 +10,7 @@ import { createDeep } from "./deepai";
 import { hubURL } from "../characterCards";
 import { NovelAIBadWordIds, stringlizeNAIChat } from "./models/nai";
 import { strongBan, tokenizeNum } from "../tokenizer";
-import { runLocalModel } from "./models/local";
+import { runGGUFModel } from "./models/local";
 import { risuChatParser } from "../parser";
 import { SignatureV4 } from "@smithy/signature-v4";
 import { HttpRequest } from "@smithy/protocol-http";
@@ -1685,7 +1685,36 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
         const suggesting = model === "submodel"
         const proompt = stringlizeChatOba(formated, currentChar.name, suggesting, arg.continue)
         const stopStrings = getStopStrings(suggesting)
-        await runLocalModel(proompt)
+        const modelPath = aiModel.replace('local_', '')
+        const res = await runGGUFModel({
+            prompt: proompt,
+            modelPath: modelPath,
+            temperature: temperature,
+            top_p: db.top_p,
+            top_k: db.top_k,
+            maxTokens: maxTokens,
+            presencePenalty: arg.PresensePenalty || (db.PresensePenalty / 100),
+            frequencyPenalty: arg.frequencyPenalty || (db.frequencyPenalty / 100),
+            repeatPenalty: 0,
+            maxContext: db.maxContext,
+            stop: stopStrings,
+        })
+        let decoded = ''
+        const transtream = new TransformStream<Uint8Array, StreamResponseChunk>({
+            async transform(chunk, control) {
+                const decodedChunk = new TextDecoder().decode(chunk)
+                decoded += decodedChunk
+                control.enqueue({
+                    "0": decoded
+                })
+            }
+        })
+        res.pipeTo(transtream.writable)
+        return {
+            type: 'streaming',
+            result: transtream.readable
+        }
     }
     return {
         type: 'fail',

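Note the shape the TransformStream emits: each enqueued value carries the cumulative text decoded so far under key "0", not a delta, so every downstream read sees the whole message to date. A Python rendering of that accumulation (hypothetical helper, not part of the commit; note that decoding each chunk independently — as the TypeScript also does by constructing a fresh TextDecoder per chunk — can mangle a multibyte character split across chunk boundaries):

    def accumulate(chunks):
        """chunks: an iterable of UTF-8 byte chunks from the /llamacpp stream."""
        decoded = ""
        for raw in chunks:
            decoded += raw.decode("utf-8")  # per-chunk decode, like the TS code
            yield {"0": decoded}            # cumulative text, not a delta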
View File

@@ -5,6 +5,7 @@ import { get } from "svelte/store";
 import type { OpenAIChat } from "./process";
 import { supportsInlayImage } from "./image";
 import { risuChatParser } from "./parser";
+import { tokenizeGGUFModel } from "./process/models/local";

 async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     let db = get(DataBase)
@@ -21,12 +22,14 @@ async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     if(db.aiModel.startsWith('mistral')){
         return await tokenizeWebTokenizers(data, 'mistral')
     }
-    if(db.aiModel.startsWith('local_') ||
-        db.aiModel === 'mancer' ||
+    if(db.aiModel === 'mancer' ||
         db.aiModel === 'textgen_webui' ||
         (db.aiModel === 'reverse_proxy' && db.reverseProxyOobaMode)){
         return await tokenizeWebTokenizers(data, 'llama')
     }
+    if(db.aiModel.startsWith('local_')){
+        return await tokenizeGGUFModel(data)
+    }
     if(db.aiModel === 'ooba'){
         if(db.reverseProxyOobaArgs.tokenizer === 'mixtral' || db.reverseProxyOobaArgs.tokenizer === 'mistral'){
             return await tokenizeWebTokenizers(data, 'mistral')
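With this change, local_ models get token counts from the running server's /llamacpp/tokenize endpoint — that is, from the actual GGUF vocabulary — instead of the generic web-tokenizer llama encoding, which matters for fitting prompts into maxContext. A quick sanity check from Python (sketch, not part of the commit; the key is read as in the earlier example and the model path is a placeholder):

    import os
    import sys

    import requests

    key = open(os.path.join(os.path.dirname(sys.executable), "key.txt")).read()
    tokens = requests.post(
        "http://localhost:10026/llamacpp/tokenize",
        headers={"x-risu-auth": key},
        json={"prompt": "Hello world", "n_ctx": 2000,
              "model_path": "/path/to/model.gguf"},  # placeholder
    ).json()
    print(len(tokens), tokens[:8])  # the count drives context trimming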