Experimental llamacpp support
@@ -1,31 +0,0 @@
from llama_cpp import Llama
from pydantic import BaseModel

class LlamaItem(BaseModel):
    prompt: str
    model_path: str
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    max_tokens: int = 256,
    presence_penalty: float = 0,
    frequency_penalty: float = 0,
    repeat_penalty: float = 1.1,
    n_ctx: int = 2000

def stream_chat_llamacpp(item:LlamaItem):
    if last_model_path != item.model_path or llm is None or n_ctx != item.n_ctx:
        llm = Llama(model_path=item.model_path, n_ctx=n_ctx)
        last_model_path = item.model_path
        n_ctx = item.n_ctx
    chunks = llm.create_completion(
        prompt = item.prompt,
    )
    for chunk in chunks:
        cont = chunk
        print(cont, end="")
        yield cont.encode()

n_ctx = 2000
last_model_path = ""
llm:Llama
@@ -1,12 +1,17 @@
from fastapi import FastAPI, Header
from fastapi.responses import StreamingResponse
from llamacpp import LlamaItem, stream_chat_llamacpp
from typing import Annotated, Union
from llama_cpp import Llama, CompletionChunk
from pydantic import BaseModel
from typing import Annotated, Union, List
from fastapi.middleware.cors import CORSMiddleware
import uuid
import os
import sys

# Write key for authentication

app = FastAPI()
key_dir = os.path.join(os.getcwd(), "key.txt")
key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
if not os.path.exists(key_dir):
    f = open(key_dir, 'w')
    f.write(str(uuid.uuid4()))
@@ -15,11 +20,14 @@ f = open(key_dir, 'r')
key = f.read()
f.close()

@app.post("/llamacpp")
async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None):
    if key != x_risu_auth:
        return {"error": "Invalid key"}
    return StreamingResponse(stream_chat_llamacpp(item))
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Authentication endpoint

@app.get("/")
async def autha():
@@ -27,4 +35,70 @@ async def autha():
@app.get("/auth")
async def auth():
    return {"dir": key_dir}


# Llamacpp endpoint

class LlamaItem(BaseModel):
    prompt: str
    model_path: str
    temperature: float
    top_p: float
    top_k: int
    max_tokens: int
    presence_penalty: float
    frequency_penalty: float
    repeat_penalty: float
    n_ctx: int
    stop: List[str]

app.n_ctx = 2000
app.last_model_path = ""
app.llm: Llama = None
def stream_chat_llamacpp(item: LlamaItem):
    # Reload the model when the path or context size changes
    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
        app.last_model_path = item.model_path
        app.n_ctx = item.n_ctx
    chunks = app.llm.create_completion(
        prompt = item.prompt,
        temperature = item.temperature,
        top_p = item.top_p,
        top_k = item.top_k,
        max_tokens = item.max_tokens,
        presence_penalty = item.presence_penalty,
        frequency_penalty = item.frequency_penalty,
        repeat_penalty = item.repeat_penalty,
        stop=item.stop,
        stream=True
    )
    for chunk in chunks:
        cont: CompletionChunk = chunk
        encoded = cont["choices"][0]["text"]
        print(encoded, end="")
        yield encoded
@app.post("/llamacpp")
async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
    if key != x_risu_auth:
        return {"error": "Invalid key"}
    return StreamingResponse(stream_chat_llamacpp(item))
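A minimal client for this streaming endpoint might look like the sketch below. It assumes the server is reachable at http://127.0.0.1:8000 (the commit does not show how the app is launched; this is the usual uvicorn default), that the key is read from the key.txt file the server writes next to its executable, and that the model path is a placeholder. FastAPI reads the x_risu_auth parameter from an x-risu-auth request header.

# Hypothetical client sketch; host/port and model path are placeholders.
import requests

with open("key.txt") as f:          # key file written by the server above
    key = f.read()

resp = requests.post(
    "http://127.0.0.1:8000/llamacpp",      # assumed local address
    headers={"x-risu-auth": key},          # FastAPI maps x_risu_auth to this header
    json={
        "prompt": "### Instruction:\nHello\n\n### Response:\n",
        "model_path": "/path/to/model.gguf",   # placeholder
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_tokens": 256,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "repeat_penalty": 1.1,
        "n_ctx": 2048,
        "stop": ["### Instruction:"],
    },
    stream=True,
)
# The endpoint returns a StreamingResponse, so the text arrives chunk by chunk.
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8", errors="ignore"), end="")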
class LlamaTokenizeItem(BaseModel):
    prompt: str
    model_path: str
    n_ctx: int
@app.post("/llamacpp/tokenize")
async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
    if key != x_risu_auth:
        return {"error": "Invalid key"}
    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
        app.last_model_path = item.model_path
        app.n_ctx = item.n_ctx

    return app.llm.tokenize(item.prompt.encode('utf-8'))
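The tokenize endpoint is called the same way but returns a plain JSON array of token ids instead of a stream. A sketch under the same assumptions (local address, key from key.txt, placeholder model path):

# Hypothetical tokenize call; same assumptions as the streaming sketch above.
import requests

with open("key.txt") as f:
    key = f.read()

tokens = requests.post(
    "http://127.0.0.1:8000/llamacpp/tokenize",   # assumed local address
    headers={"x-risu-auth": key},
    json={
        "prompt": "Hello, llama.cpp!",
        "model_path": "/path/to/model.gguf",     # placeholder
        "n_ctx": 2048,
    },
).json()
print(len(tokens), "tokens:", tokens)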