Refactor code for local models
This commit is contained in:
@@ -66,17 +66,39 @@ def stream_chat_llamacpp(item:LlamaItem):
|
||||
chunks = app.llm.create_completion(
|
||||
prompt = item.prompt,
|
||||
temperature = item.temperature,
|
||||
top_p = item.top_p,
|
||||
top_k = item.top_k,
|
||||
max_tokens = item.max_tokens,
|
||||
presence_penalty = item.presence_penalty,
|
||||
frequency_penalty = item.frequency_penalty,
|
||||
repeat_penalty = item.repeat_penalty,
|
||||
stop=item.stop,
|
||||
stream=True
|
||||
# top_p = item.top_p,
|
||||
# top_k = item.top_k,
|
||||
# max_tokens = item.max_tokens,
|
||||
# presence_penalty = item.presence_penalty,
|
||||
# frequency_penalty = item.frequency_penalty,
|
||||
# repeat_penalty = item.repeat_penalty,
|
||||
# stop=item.stop,
|
||||
stream=False,
|
||||
)
|
||||
if(type(chunks) == str):
|
||||
print(chunks, end="")
|
||||
yield chunks
|
||||
return
|
||||
if(type(chunks) == bytes):
|
||||
print(chunks.decode('utf-8'), end="")
|
||||
yield chunks.decode('utf-8')
|
||||
return
|
||||
if(type(chunks) == dict and "choices" in chunks):
|
||||
print(chunks["choices"][0]["text"], end="")
|
||||
yield chunks["choices"][0]["text"]
|
||||
return
|
||||
|
||||
for chunk in chunks:
|
||||
if(type(chunk) == str):
|
||||
print(chunk, end="")
|
||||
yield chunk
|
||||
continue
|
||||
if(type(chunk) == bytes):
|
||||
print(chunk.decode('utf-8'), end="")
|
||||
yield chunk.decode('utf-8')
|
||||
continue
|
||||
cont:CompletionChunk = chunk
|
||||
print(cont)
|
||||
encoded = cont["choices"][0]["text"]
|
||||
print(encoded, end="")
|
||||
yield encoded
|
||||
|
||||
Reference in New Issue
Block a user