risuai/src-tauri/src-python/main.py

from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama, CompletionChunk
from pydantic import BaseModel
from typing import Annotated, Union, List
from fastapi.middleware.cors import CORSMiddleware
import uuid
import os
import sys
app = FastAPI()

# Write the authentication key next to the executable on first run, then load it;
# requests must echo this key back in the x-risu-auth header.
key_dir = os.path.join(os.path.dirname(sys.executable), "key.txt")
if not os.path.exists(key_dir):
    with open(key_dir, 'w') as f:
        f.write(str(uuid.uuid4()))
with open(key_dir, 'r') as f:
    key = f.read()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Authentication endpoints: both return the on-disk location of key.txt
@app.get("/")
async def autha():
    return {"dir": key_dir}

@app.get("/auth")
async def auth():
    return {"dir": key_dir}
# Llamacpp endpoint
class LlamaItem(BaseModel):
    prompt: str
    model_path: str
    temperature: float
    top_p: float
    top_k: int
    max_tokens: int
    presence_penalty: float
    frequency_penalty: float
    repeat_penalty: float
    n_ctx: int
    stop: List[str]
# Cache the loaded model on the app object so it is only reloaded when the
# model path or context size changes.
app.n_ctx = 2000
app.last_model_path = ""
app.llm: Llama = None
def stream_chat_llamacpp(item: LlamaItem):
    # Reload the model only if the requested path or context size changed.
    if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
        app.llm = Llama(model_path=item.model_path, n_ctx=item.n_ctx + 200)
        app.last_model_path = item.model_path
        app.n_ctx = item.n_ctx
    chunks = app.llm.create_completion(
        prompt = item.prompt,
        temperature = item.temperature,
        # top_p = item.top_p,
        # top_k = item.top_k,
        # max_tokens = item.max_tokens,
        # presence_penalty = item.presence_penalty,
        # frequency_penalty = item.frequency_penalty,
        # repeat_penalty = item.repeat_penalty,
        # stop=item.stop,
        stream=False,
    )
    # Non-streaming result: yield the whole completion at once.
    if type(chunks) == str:
        print(chunks, end="")
        yield chunks
        return
    if type(chunks) == bytes:
        print(chunks.decode('utf-8'), end="")
        yield chunks.decode('utf-8')
        return
    if type(chunks) == dict and "choices" in chunks:
        print(chunks["choices"][0]["text"], end="")
        yield chunks["choices"][0]["text"]
        return
    # Streaming result: yield each chunk's text as it arrives.
    for chunk in chunks:
        if type(chunk) == str:
            print(chunk, end="")
            yield chunk
            continue
        if type(chunk) == bytes:
            print(chunk.decode('utf-8'), end="")
            yield chunk.decode('utf-8')
            continue
        cont: CompletionChunk = chunk
        print(cont)
        encoded = cont["choices"][0]["text"]
        print(encoded, end="")
        yield encoded
@app.post("/llamacpp")
async def llamacpp(item:LlamaItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> StreamingResponse:
if key != x_risu_auth:
return {"error": "Invalid key"}
return StreamingResponse(stream_chat_llamacpp(item))
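# A minimal client sketch, for illustration only (commented out so it never runs
# inside the server process). The port is a placeholder for whatever the sidecar
# listens on, the model path is made up, and "requests" is just one possible HTTP
# client; the key is the value stored in key.txt above.
#
#   import requests
#   payload = {
#       "prompt": "Hello", "model_path": "/path/to/model.gguf",
#       "temperature": 0.7, "top_p": 0.9, "top_k": 40, "max_tokens": 256,
#       "presence_penalty": 0.0, "frequency_penalty": 0.0,
#       "repeat_penalty": 1.1, "n_ctx": 2048, "stop": [],
#   }
#   with requests.post("http://127.0.0.1:PORT/llamacpp", json=payload,
#                      headers={"x-risu-auth": key}, stream=True) as resp:
#       for chunk in resp.iter_content(chunk_size=None):
#           print(chunk.decode("utf-8"), end="")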
class LlamaTokenizeItem(BaseModel):
    prompt: str
    model_path: str
    n_ctx: int
@app.post("/llamacpp/tokenize")
async def llamacpp_tokenize(item:LlamaTokenizeItem, x_risu_auth: Annotated[Union[str, None], Header()] = None) -> List[int]:
if key != x_risu_auth:
return {"error": "Invalid key"}
if app.last_model_path != item.model_path or app.llm is None or app.n_ctx != item.n_ctx:
app.llm = Llama(model_path=item.model_path, n_ctx=app.n_ctx + 200)
app.last_model_path = item.model_path
app.n_ctx = item.n_ctx
return app.llm.tokenize(item.prompt.encode('utf-8'))
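# Tokenize sketch under the same assumptions as the completion example above
# (placeholder port and model path, "requests" as the client):
#
#   resp = requests.post("http://127.0.0.1:PORT/llamacpp/tokenize",
#                        json={"prompt": "Hello world",
#                              "model_path": "/path/to/model.gguf", "n_ctx": 2048},
#                        headers={"x-risu-auth": key})
#   token_ids = resp.json()  # a JSON array of integer token ids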