Experimental llamacpp support
@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama
@@ -1,12 +1,17 @@
 from fastapi import FastAPI, Header
 from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
 import uuid
 import os
+import sys
 
+# Write key for authentication
+
 app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
 if not os.path.exists(key_dir):
     f = open(key_dir, 'w')
     f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
 key = f.read()
 f.close()
 
-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
+# Authentication endpoint
+
 @app.get("/")
 async def autha():
@@ -28,3 +36,69 @@ async def autha():
 @app.get("/auth")
 async def auth():
     return {"dir": key_dir}
+
+
+# Llamacpp endpoint
+
+class LlamaItem(BaseModel):
+    prompt: str
+    model_path: str
+    temperature: float
+    top_p: float
+    top_k: int
+    max_tokens: int
+    presence_penalty: float
+    frequency_penalty: float
+    repeat_penalty: float
+    n_ctx: int
+    stop: List[str]
+
+app.n_ctx = 2000
+app.last_model_path = ""
+app.llm:Llama = None
+
+
+def stream_chat_llamacpp(item:LlamaItem):
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+    chunks = app.llm.create_completion(
+        prompt = item.prompt,
+        temperature = item.temperature,
+        top_p = item.top_p,
+        top_k = item.top_k,
+        max_tokens = item.max_tokens,
+        presence_penalty = item.presence_penalty,
+        frequency_penalty = item.frequency_penalty,
+        repeat_penalty = item.repeat_penalty,
+        stop=item.stop,
+        stream=True
+    )
+    for chunk in chunks:
+        cont:CompletionChunk = chunk
+        encoded = cont["choices"][0]["text"]
+        print(encoded, end="")
+        yield encoded
+
+@app.post("/llamacpp")
+async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    return StreamingResponse(stream_chat_llamacpp(item))
+
+class LlamaTokenizeItem(BaseModel):
+    prompt: str
+    model_path: str
+    n_ctx: int
+
+@app.post("/llamacpp/tokenize")
+async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+
+    return app.llm.tokenize(item.prompt.encode('utf-8'))
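
For reference, the two new endpoints can be exercised directly once the bundled server is running. A minimal client sketch, not part of the commit, assuming the server listens on localhost:10026 (as the frontend code in this diff expects), the `requests` package is installed, and the prompt and model path are placeholders:

# Minimal sketch of a client for the new endpoints (assumptions noted above).
import requests

BASE = "http://localhost:10026"

# GET /auth reports where key.txt lives; read the key from that file.
key_path = requests.get(BASE + "/auth").json()["dir"]
with open(key_path) as f:
    key = f.read()
headers = {"x-risu-auth": key}

# POST /llamacpp streams plain-text completion chunks.
body = {
    "prompt": "Hello, world.",                # placeholder prompt
    "model_path": "/path/to/model.gguf",      # placeholder model path
    "temperature": 0.7, "top_p": 0.95, "top_k": 40, "max_tokens": 256,
    "presence_penalty": 0, "frequency_penalty": 0, "repeat_penalty": 1.1,
    "n_ctx": 2000, "stop": ["\n\n"],
}
with requests.post(BASE + "/llamacpp", json=body, headers=headers, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")

# POST /llamacpp/tokenize returns the prompt's token ids as a JSON list.
tokens = requests.post(
    BASE + "/llamacpp/tokenize",
    json={"prompt": "Hello", "model_path": "/path/to/model.gguf", "n_ctx": 2000},
    headers=headers,
).json()
print(len(tokens), "tokens")
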
@@ -145,8 +145,8 @@
         <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('reverse_proxy')}}>Reverse Proxy</button>
         {#if import.meta.env.DEV}
             <button class="hover:bg-selected px-6 py-2 text-lg" on:click={async () => {
-                changeModel('local_gptq')
-            }}>Local Model GPTQ <Help key="experimental"/> </button>
+                changeModel('local_') // TODO: Fix this
+            }}>Local GGUF Model <Help key="experimental"/> </button>
         {/if}
         <button class="hover:bg-selected px-6 py-2 text-lg" on:click={() => {changeModel('ooba')}}>Oobabooga</button>
         {#if showUnrec}
@@ -2,7 +2,7 @@ import { invoke } from "@tauri-apps/api/tauri";
 import { globalFetch } from "src/ts/storage/globalApi";
 import { sleep } from "src/ts/util";
 import * as path from "@tauri-apps/api/path";
-import { exists } from "@tauri-apps/api/fs";
+import { exists, readTextFile } from "@tauri-apps/api/fs";
 import { alertClear, alertError, alertMd, alertWait } from "src/ts/alert";
 import { get } from "svelte/store";
 import { DataBase } from "src/ts/storage/database";
@@ -130,7 +130,7 @@ export async function loadExllamaFull(){
 }
 
 
-export async function runLocalModel(prompt:string){
+async function runLocalModelOld(prompt:string){
     const db = get(DataBase)
 
     if(!serverRunning){
@@ -155,10 +155,33 @@ export async function runLocalModel(prompt:string){
     console.log(gen)
 }
 
+let initPython = false
 export async function installPython(){
+    if(initPython){
+        return
+    }
+    initPython = true
     const appDir = await path.appDataDir()
     const completedPath = await path.join(appDir, 'python', 'completed.txt')
     if(await exists(completedPath)){
+        alertWait("Python is already installed, skipping")
+    }
+    else{
+        alertWait("Installing Python")
+        await invoke("install_python", {
+            path: appDir
+        })
+        alertWait("Installing Pip")
+        await invoke("install_pip", {
+            path: appDir
+        })
+        alertWait("Rewriting requirements")
+        await invoke('post_py_install', {
+            path: appDir
+        })
+
+        alertClear()
+    }
         const dependencies = [
             'pydantic',
             'scikit-build',
@@ -177,26 +200,94 @@ export async function installPython(){
         })
     }
 
-    const srvPath = await resolveResource('/src-python/')
     await invoke('run_py_server', {
        pyPath: appDir,
     })
-    alertMd("Python Server is running at: " + srvPath)
-    return
-}
+    await sleep(4000)
 
-    alertWait("Installing Python")
-    await invoke("install_python", {
-        path: appDir
-    })
-    alertWait("Installing Pip")
-    await invoke("install_pip", {
-        path: appDir
-    })
-    alertWait("Rewriting requirements")
-    await invoke('post_py_install', {
-        path: appDir
-    })
-
     alertClear()
+    return
+
+}
+
+export async function getLocalKey(retry = true) {
+    try {
+        const ft = await fetch("http://localhost:10026/")
+        const keyJson = await ft.json()
+        const keyPath = keyJson.dir
+        const key = await readTextFile(keyPath)
+        return key
+    } catch (error) {
+        if(!retry){
+            throw `Error when getting local key: ${error}`
+        }
+        //if is cors error
+        if(
+            error.message.includes("NetworkError when attempting to fetch resource.")
+            || error.message.includes("Failed to fetch")
+        ){
+            await installPython()
+            return await getLocalKey(false)
+        }
+        else{
+            throw `Error when getting local key: ${error}`
+        }
+    }
+}
+
+export async function runGGUFModel(arg:{
+    prompt: string
+    modelPath: string
+    temperature: number
+    top_p: number
+    top_k: number
+    maxTokens: number
+    presencePenalty: number
+    frequencyPenalty: number
+    repeatPenalty: number
+    maxContext: number
+    stop: string[]
+}) {
+    const key = await getLocalKey()
+    const b = await fetch("http://localhost:10026/llamacpp", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: arg.prompt,
+            model_path: arg.modelPath,
+            temperature: arg.temperature,
+            top_p: arg.top_p,
+            top_k: arg.top_k,
+            max_tokens: arg.maxTokens,
+            presence_penalty: arg.presencePenalty,
+            frequency_penalty: arg.frequencyPenalty,
+            repeat_penalty: arg.repeatPenalty,
+            n_ctx: arg.maxContext,
+            stop: arg.stop
+        })
+    })
+
+    return b.body
+}
+
+export async function tokenizeGGUFModel(prompt:string):Promise<number[]> {
+    const key = await getLocalKey()
+    const db = get(DataBase)
+    const modelPath = db.aiModel.replace('local_', '')
+    const b = await fetch("http://localhost:10026/llamacpp/tokenize", {
+        method: "POST",
+        headers: {
+            "Content-Type": "application/json",
+            "x-risu-auth": key
+        },
+        body: JSON.stringify({
+            prompt: prompt,
+            n_ctx: db.maxContext,
+            model_path: modelPath
+        })
+    })
+
+    return await b.json()
 }
@@ -10,7 +10,7 @@ import { createDeep } from "./deepai";
 import { hubURL } from "../characterCards";
 import { NovelAIBadWordIds, stringlizeNAIChat } from "./models/nai";
 import { strongBan, tokenizeNum } from "../tokenizer";
-import { runLocalModel } from "./models/local";
+import { runGGUFModel } from "./models/local";
 import { risuChatParser } from "../parser";
 import { SignatureV4 } from "@smithy/signature-v4";
 import { HttpRequest } from "@smithy/protocol-http";
@@ -1685,7 +1685,36 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
         const suggesting = model === "submodel"
         const proompt = stringlizeChatOba(formated, currentChar.name, suggesting, arg.continue)
         const stopStrings = getStopStrings(suggesting)
-        await runLocalModel(proompt)
+        const modelPath = aiModel.replace('local_', '')
+        const res = await runGGUFModel({
+            prompt: proompt,
+            modelPath: modelPath,
+            temperature: temperature,
+            top_p: db.top_p,
+            top_k: db.top_k,
+            maxTokens: maxTokens,
+            presencePenalty: arg.PresensePenalty || (db.PresensePenalty / 100),
+            frequencyPenalty: arg.frequencyPenalty || (db.frequencyPenalty / 100),
+            repeatPenalty: 0,
+            maxContext: db.maxContext,
+            stop: stopStrings,
+        })
+        let decoded = ''
+        const transtream = new TransformStream<Uint8Array, StreamResponseChunk>({
+            async transform(chunk, control) {
+                const decodedChunk = new TextDecoder().decode(chunk)
+                decoded += decodedChunk
+                control.enqueue({
+                    "0": decoded
+                })
+            }
+        })
+        res.pipeTo(transtream.writable)
+
+        return {
+            type: 'streaming',
+            result: transtream.readable
+        }
     }
     return {
         type: 'fail',
@@ -5,6 +5,7 @@ import { get } from "svelte/store";
 import type { OpenAIChat } from "./process";
 import { supportsInlayImage } from "./image";
 import { risuChatParser } from "./parser";
+import { tokenizeGGUFModel } from "./process/models/local";
 
 async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     let db = get(DataBase)
@@ -21,12 +22,14 @@ async function encode(data:string):Promise<(number[]|Uint32Array|Int32Array)>{
     if(db.aiModel.startsWith('mistral')){
         return await tokenizeWebTokenizers(data, 'mistral')
     }
-    if(db.aiModel.startsWith('local_') ||
-        db.aiModel === 'mancer' ||
+    if(db.aiModel === 'mancer' ||
         db.aiModel === 'textgen_webui' ||
         (db.aiModel === 'reverse_proxy' && db.reverseProxyOobaMode)){
         return await tokenizeWebTokenizers(data, 'llama')
     }
+    if(db.aiModel.startsWith('local_')){
+        return await tokenizeGGUFModel(data)
+    }
     if(db.aiModel === 'ooba'){
         if(db.reverseProxyOobaArgs.tokenizer === 'mixtral' || db.reverseProxyOobaArgs.tokenizer === 'mistral'){
             return await tokenizeWebTokenizers(data, 'mistral')