Experimental llamacpp support

kwaroran committed 2024-01-16 10:56:23 +09:00
parent 91735d0512
commit 9db4810bbc
6 changed files with 248 additions and 82 deletions

View File

@@ -1,31 +0,0 @@
-from llama_cpp import Llama
-from pydantic import BaseModel
-
-class LlamaItem(BaseModel):
-    prompt: str
-    model_path: str
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-    top_k: int = 40,
-    max_tokens: int = 256,
-    presence_penalty: float = 0,
-    frequency_penalty: float = 0,
-    repeat_penalty: float = 1.1,
-    n_ctx: int = 2000
-
-def stream_chat_llamacpp(item:LlamaItem):
-    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
-        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
-        last_model_path = item.model_path
-        n_ctx = item.n_ctx
-    chunks = llm.create_completion(
-        prompt = item.prompt,
-    )
-    for chunk in chunks:
-        cont = chunk
-        print(cont, end="")
-        yield cont.encode()
-
-n_ctx = 2000
-last_model_path = ""
-llm:Llama

View File

@@ -1,12 +1,17 @@
 from fastapi import FastAPI, Header
 from fastapi.responses import StreamingResponse
-from llamacpp import LlamaItem, stream_chat_llamacpp
-from typing import Annotated, Union
+from llama_cpp import Llama, CompletionChunk
+from pydantic import BaseModel
+from typing import Annotated, Union, List
+from fastapi.middleware.cors import CORSMiddleware
 import uuid
 import os
+import sys

 # Write key for authentication
 app = FastAPI()
-key_dir = os.path.join(os.getcwd(), "key.txt")
+key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
 if not os.path.exists(key_dir):
     f = open(key_dir, 'w')
     f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
 key = f.read()
 f.close()

-@app.post("/llamacpp")
-async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
-    if key != x_risu_auth:
-        return {"error": "Invalid key"}
-    return StreamingResponse(stream_chat_llamacpp(item))
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

 # Authentication endpoint
 @app.get("/")
 async def autha():
@@ -27,4 +35,70 @@ async def autha():

 @app.get("/auth")
 async def auth():
-    return {"dir": key_dir}
+    return {"dir": key_dir}
+
+# Llamacpp endpoint
+class LlamaItem(BaseModel):
+    prompt: str
+    model_path: str
+    temperature: float
+    top_p: float
+    top_k: int
+    max_tokens: int
+    presence_penalty: float
+    frequency_penalty: float
+    repeat_penalty: float
+    n_ctx: int
+    stop: List[str]
+
+app.n_ctx = 2000
+app.last_model_path = ""
+app.llm: Llama = None
+
+def stream_chat_llamacpp(item:LlamaItem):
+    # (Re)load the model when the requested path or context size changes
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+    chunks = app.llm.create_completion(
+        prompt = item.prompt,
+        temperature = item.temperature,
+        top_p = item.top_p,
+        top_k = item.top_k,
+        max_tokens = item.max_tokens,
+        presence_penalty = item.presence_penalty,
+        frequency_penalty = item.frequency_penalty,
+        repeat_penalty = item.repeat_penalty,
+        stop=item.stop,
+        stream=True
+    )
+    # Yield each generated text fragment to the client as it arrives
+    for chunk in chunks:
+        cont:CompletionChunk = chunk
+        encoded = cont["choices"][0]["text"]
+        print(encoded, end="")
+        yield encoded
+
+@app.post("/llamacpp")
+async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    return StreamingResponse(stream_chat_llamacpp(item))
+
+class LlamaTokenizeItem(BaseModel):
+    prompt: str
+    model_path: str
+    n_ctx: int
+
+@app.post("/llamacpp/tokenize")
+async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
+    if key != x_risu_auth:
+        return {"error": "Invalid key"}
+    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
+        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
+        app.last_model_path = item.model_path
+        app.n_ctx = item.n_ctx
+    return app.llm.tokenize(item.prompt.encode('utf-8'))
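
As a usage reference, here is a minimal client-side sketch of how these endpoints could be called once the server is running (e.g. via uvicorn). The host and port, the key.txt location, and the model path are illustrative assumptions; the endpoint paths, the request fields, and the x-risu-auth header are taken from the code in this commit.

import requests

# The server writes a UUID key to key.txt next to its executable on first
# start; reading it from the current directory here is an assumption.
with open("key.txt") as f:
    key = f.read()

BASE = "http://127.0.0.1:8000"  # assumed host/port
HEADERS = {"x-risu-auth": key}

# LlamaItem declares no defaults in this commit, so every field must be sent.
payload = {
    "prompt": "### User:\nHello!\n\n### Assistant:\n",
    "model_path": "/path/to/model.gguf",  # placeholder path
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_tokens": 256,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "repeat_penalty": 1.1,
    "n_ctx": 2000,
    "stop": ["### User:"],
}

# /llamacpp streams raw completion text; print fragments as they arrive.
with requests.post(f"{BASE}/llamacpp", json=payload, headers=HEADERS, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"), end="", flush=True)

# /llamacpp/tokenize returns the token ids for a prompt.
tokens = requests.post(
    f"{BASE}/llamacpp/tokenize",
    json={"prompt": payload["prompt"], "model_path": payload["model_path"], "n_ctx": 2000},
    headers=HEADERS,
).json()
print(f"\n{len(tokens)} tokens")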