from fastapi import FastAPI, Header
from fastapi.responses import StreamingResponse
from llama_cpp import Llama, CompletionChunk
from pydantic import BaseModel
from typing import Annotated, Union, List
from fastapi.middleware.cors import CORSMiddleware
import uuid
import os
import sys

# Write key for authentication

app = FastAPI()
key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
if not os.path.exists(key_dir):
    with open(key_dir, 'w') as f:
        f.write(str(uuid.uuid4()))
with open(key_dir, 'r') as f:
    key = f.read()

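# Clients read the generated key from key.txt and must send it back in the
# "x-risu-auth" header; requests with a missing or mismatched key are rejected.
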
# Allow cross-origin requests from any origin, with all methods and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Authentication endpoints: both return the path of the key file on disk

@app.get("/")
async def autha():
    return {"dir": key_dir}


@app.get("/auth")
async def auth():
    return {"dir": key_dir}


# Llamacpp endpoint

class LlamaItem(BaseModel):
    prompt: str
    model_path: str
    temperature: float
    top_p: float
    top_k: int
    max_tokens: int
    presence_penalty: float
    frequency_penalty: float
    repeat_penalty: float
    n_ctx: int
    stop: List[str]

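# model_path and n_ctx configure the Llama instance itself; the remaining fields
# mirror llama-cpp-python sampling parameters, though most of them are currently
# left commented out in the create_completion call below.
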
app.n_ctx = 2000
app.last_model_path = ""
app.llm: Llama = None


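# The generator below lazily (re)loads the Llama model whenever the requested
# model path or context size changes, then yields the generated text so it can
# be sent back through a StreamingResponse.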
def stream_chat_llamacpp(item: LlamaItem):
    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
        # Reload the model with the requested context size (plus a little headroom)
        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
        app.last_model_path = item.model_path
        app.n_ctx = item.n_ctx
    chunks = app.llm.create_completion(
        prompt=item.prompt,
        temperature=item.temperature,
        # top_p=item.top_p,
        # top_k=item.top_k,
        # max_tokens=item.max_tokens,
        # presence_penalty=item.presence_penalty,
        # frequency_penalty=item.frequency_penalty,
        # repeat_penalty=item.repeat_penalty,
        # stop=item.stop,
        stream=False,
    )
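    # With stream=False, create_completion returns the finished completion as a
    # single object, so the branches below yield it as one chunk; the loop further
    # down handles the iterator that stream=True would return.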
    if isinstance(chunks, str):
        print(chunks, end="")
        yield chunks
        return
    if isinstance(chunks, bytes):
        print(chunks.decode('utf-8'), end="")
        yield chunks.decode('utf-8')
        return
    if isinstance(chunks, dict) and "choices" in chunks:
        print(chunks["choices"][0]["text"], end="")
        yield chunks["choices"][0]["text"]
        return

    for chunk in chunks:
        if isinstance(chunk, str):
            print(chunk, end="")
            yield chunk
            continue
        if isinstance(chunk, bytes):
            print(chunk.decode('utf-8'), end="")
            yield chunk.decode('utf-8')
            continue
        cont: CompletionChunk = chunk
        print(cont)
        encoded = cont["choices"][0]["text"]
        print(encoded, end="")
        yield encoded


@app.post("/llamacpp")
async def llamacpp(item: LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
    if key != x_risu_auth:
        return {"error": "Invalid key"}
    return StreamingResponse(stream_chat_llamacpp(item))


class LlamaTokenizeItem(BaseModel):
    prompt: str
    model_path: str
    n_ctx: int

@app.post("/llamacpp/tokenize")
|
|
async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
|
|
if key != x_risu_auth:
|
|
return {"error": "Invalid key"}
|
|
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
|
|
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
|
|
app.last_model_path = item.model_path
|
|
app.n_ctx = item.n_ctx
|
|
|
|
return app.llm.tokenize(item.prompt.encode('utf-8')) |
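
# Example request (illustrative): assuming the server is run with uvicorn on the
# default port 8000 and "model.gguf" is a local GGUF model file, a client sends
# the key from key.txt in the "x-risu-auth" header:
#
#   curl -N -X POST http://127.0.0.1:8000/llamacpp \
#        -H "Content-Type: application/json" \
#        -H "x-risu-auth: <contents of key.txt>" \
#        -d '{"prompt": "Hello", "model_path": "model.gguf", "temperature": 0.7,
#             "top_p": 0.9, "top_k": 40, "max_tokens": 128, "presence_penalty": 0.0,
#             "frequency_penalty": 0.0, "repeat_penalty": 1.1, "n_ctx": 2000,
#             "stop": []}'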