Add native rust streamed fetch implementation

This commit is contained in:
kwaroran
2024-03-06 21:49:32 +09:00
parent 0117fab49c
commit 826eac60d1
5 changed files with 410 additions and 22 deletions

View File

@@ -4,7 +4,7 @@ import { DataBase, setDatabase, type character } from "../storage/database";
import { pluginProcess } from "../plugins/plugins";
import { language } from "../../lang";
import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, getStopStrings, unstringlizeAIN, unstringlizeChat } from "./stringlize";
import { addFetchLog, globalFetch, isNodeServer, isTauri } from "../storage/globalApi";
import { addFetchLog, fetchNative, globalFetch, isNodeServer, isTauri, textifyReadableStream } from "../storage/globalApi";
import { sleep } from "../util";
import { createDeep } from "./deepai";
import { hubURL } from "../characterCards";
@@ -526,36 +526,24 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model'
}
}
}
const da = (throughProxi)
? await fetch(hubURL + `/proxy2`, {
body: JSON.stringify(body),
headers: {
"risu-header": encodeURIComponent(JSON.stringify(headers)),
"risu-url": encodeURIComponent(replacerURL),
"Content-Type": "application/json",
"x-risu-tk": "use"
},
method: "POST",
signal: abortSignal
})
: await fetch(replacerURL, {
body: JSON.stringify(body),
method: "POST",
headers: headers,
signal: abortSignal
})
const da = await fetchNative(replacerURL, {
body: JSON.stringify(body),
method: "POST",
headers: headers,
signal: abortSignal
})
if(da.status !== 200){
return {
type: "fail",
result: await da.text()
result: await textifyReadableStream(da.body)
}
}
if (!da.headers.get('Content-Type').includes('text/event-stream')){
return {
type: "fail",
result: await da.text()
result: await textifyReadableStream(da.body)
}
}

View File

@@ -27,6 +27,7 @@ import { Capacitor, CapacitorHttp } from '@capacitor/core';
import * as CapFS from '@capacitor/filesystem'
import { save } from "@tauri-apps/api/dialog";
import type { RisuModule } from "../process/modules";
import { listen } from '@tauri-apps/api/event'
//@ts-ignore
export const isTauri = !!window.__TAURI__
@@ -1277,4 +1278,152 @@ export class LocalWriter{
async close(){
await this.writer.close()
}
}
let fetchIndex = 0
let tauriNativeFetchData:{[key:string]:StreamedFetchChunk[]} = {}
interface StreamedFetchChunkData{
type:'chunk',
body:string,
id:string
}
interface StreamedFetchHeaderData{
type:'headers',
body:{[key:string]:string},
id:string,
status:number
}
interface StreamedFetchEndData{
type:'end',
id:string
}
type StreamedFetchChunk = StreamedFetchChunkData|StreamedFetchHeaderData|StreamedFetchEndData
listen('streamed_fetch', (event) => {
try {
const parsed = JSON.parse(event.payload as string)
const id = parsed.id
tauriNativeFetchData[id]?.push(parsed)
} catch (error) {
console.error(error)
}
})
export async function fetchNative(url:string, arg:{
body:string,
headers?:{[key:string]:string},
method?:"POST",
signal?:AbortSignal,
useRisuTk?:boolean
}):Promise<{ body: ReadableStream<Uint8Array>; headers: Headers; status: number }> {
let headers = arg.headers ?? {}
const db = get(DataBase)
let throughProxi = (!isTauri) && (!isNodeServer) && (!db.usePlainFetch) && (!Capacitor.isNativePlatform())
if(isTauri){
fetchIndex++
if(arg.signal && arg.signal.aborted){
throw new Error('aborted')
}
if(fetchIndex >= 100000){
fetchIndex = 0
}
let fetchId = fetchIndex.toString().padStart(5,'0')
tauriNativeFetchData[fetchId] = []
let resolved = false
let error = ''
invoke('streamed_fetch', {
id: fetchId,
url: url,
headers: JSON.stringify(headers),
body: arg.body,
}).then((res) => {
const parsedRes = JSON.parse(res as string)
if(!parsedRes.success){
error = parsedRes.body
resolved = true
}
})
let resHeaders:{[key:string]:string} = null
let status = 400
const readableStream = new ReadableStream<Uint8Array>({
async start(controller) {
while(!resolved || tauriNativeFetchData[fetchId].length > 0){
if(tauriNativeFetchData[fetchId].length > 0){
const data = tauriNativeFetchData[fetchId].shift()
console.log(data)
if(data.type === 'chunk'){
const chunk = Buffer.from(data.body, 'base64')
controller.enqueue(chunk)
}
if(data.type === 'headers'){
resHeaders = data.body
status = data.status
}
if(data.type === 'end'){
resolved = true
}
}
await sleep(10)
}
controller.close()
}
})
while(resHeaders === null && !resolved){
await sleep(10)
}
if(resHeaders === null){
resHeaders = {}
}
if(error !== ''){
throw new Error(error)
}
return {
body: readableStream,
headers: new Headers(resHeaders),
status: status
}
}
else if(throughProxi){
return await fetch(hubURL + `/proxy2`, {
body: arg.body,
headers: arg.useRisuTk ? {
"risu-header": encodeURIComponent(JSON.stringify(headers)),
"risu-url": encodeURIComponent(url),
"Content-Type": "application/json",
"x-risu-tk": "use"
}: {
"risu-header": encodeURIComponent(JSON.stringify(headers)),
"risu-url": encodeURIComponent(url),
"Content-Type": "application/json"
},
method: "POST",
signal: arg.signal
})
}
else{
return await fetch(url, {
body: arg.body,
headers: headers,
method: arg.method,
signal: arg.signal
})
}
}
export function textifyReadableStream(stream:ReadableStream<Uint8Array>){
return new Response(stream).text()
}