From 826eac60d1eca82334c62380ce8020cfadafac63 Mon Sep 17 00:00:00 2001 From: kwaroran Date: Wed, 6 Mar 2024 21:49:32 +0900 Subject: [PATCH] Add native rust streamed fetch implementation --- src-tauri/Cargo.lock | 174 ++++++++++++++++++++++++++++++++++++ src-tauri/Cargo.toml | 2 + src-tauri/src/main.rs | 77 +++++++++++++++- src/ts/process/request.ts | 30 ++----- src/ts/storage/globalApi.ts | 149 ++++++++++++++++++++++++++++++ 5 files changed, 410 insertions(+), 22 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index bd69c783..e5d99645 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -726,6 +726,22 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "eventsource-client" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c80c6714d1a380314fcb11a22eeff022e1e1c9642f0bb54e15dc9cb29f37b29" +dependencies = [ + "futures", + "hyper", + "hyper-rustls", + "hyper-timeout", + "log", + "pin-project", + "rand 0.8.5", + "tokio", +] + [[package]] name = "fancy-regex" version = "0.11.0" @@ -823,6 +839,21 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -830,6 +861,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -884,6 +916,7 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -1346,6 +1379,34 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "log", + "rustls", + "rustls-native-certs", + "tokio", + "tokio-rustls", +] + +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -2200,6 +2261,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -2547,12 +2628,29 @@ dependencies = [ "windows 0.37.0", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.12", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "risuai" version = "0.0.0" dependencies = [ "base64 0.21.7", "darling", + "eventsource-client", + "futures", "reqwest", "serde_json", "tar", @@ -2596,6 +2694,30 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2605,6 +2727,16 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -2653,6 +2785,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -2925,6 +3067,12 @@ dependencies = [ "system-deps 5.0.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -3469,6 +3617,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" @@ -3479,6 +3637,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -3676,6 +3844,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 21a1fdc2..92d1e630 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -21,6 +21,8 @@ reqwest = { version = "0.11.16", features = ["json"] } darling = "0.20.3" zip = "0.6.6" tar = "0.4.40" +eventsource-client = "0.12.2" +futures = "0.3.30" [features] # this feature is used for production builds or when `devPath` points to the filesystem diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index e85bf625..d708daff 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -10,6 +10,7 @@ fn greet(name: &str) -> String { use serde_json::Value; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use base64::{engine::general_purpose, Engine as _}; +use tauri::Manager; use std::io::Write; use std::{time::Duration, path::Path}; use serde_json::json; @@ -374,6 +375,79 @@ fn run_server_local(){ } + +#[tauri::command] +async fn streamed_fetch(id:String, url:String, headers: String, body: String, handle: tauri::AppHandle) -> String { + + //parse headers + let headers_json: Value = match serde_json::from_str(&headers) { + Ok(h) => h, + Err(e) => return format!(r#"{{"success":false, body:{}}}"#, e.to_string()), + }; + let app = handle.app_handle(); + + let mut headers = HeaderMap::new(); + if let Some(obj) = headers_json.as_object() { + for (key, value) in obj { + let header_name = match HeaderName::from_bytes(key.as_bytes()) { + Ok(name) => name, + Err(e) => return format!(r#"{{"success":false, body:{}}}"#, e.to_string()), + }; + let header_value = match HeaderValue::from_str(value.as_str().unwrap_or("")) { + Ok(value) => value, + Err(e) => return format!(r#"{{"success":false, body:{}}}"#, e.to_string()), + }; + headers.insert(header_name, header_value); + } + } else { + return format!(r#"{{"success":false,"body":"Invalid header JSON"}}"#); + } + + let client = reqwest::Client::new(); + let response = client + .post(&url) + .headers(headers) + .timeout(Duration::from_secs(240)) + .body(body) + .send().await; + + match response { + Ok(mut resp) => { + let headers = resp.headers(); + let header_json = header_map_to_json(headers); + app.emit_all("streamed_fetch", &format!(r#"{{"type": "headers", "body": {}, "id": "{}", "status": {}}}"#, header_json, id, resp.status().as_u16())).unwrap(); + loop { + let byt = resp.chunk().await; + match byt { + Ok(chunk) => { + if chunk.is_none() { + break; + } + let chunk = chunk.unwrap(); + let encoded = general_purpose::STANDARD.encode(chunk); + let emited = app.emit_all("streamed_fetch", &format!(r#"{{"type": "chunk", "body": "{}", "id": "{}"}}"#, encoded, id)); + + match emited { + Ok(_) => {}, + Err(e) => { + return format!(r#"{{"success":false, body:{}}}"#, e.to_string()) + } + } + } + Err(e) => { + return format!(r#"{{"success":false, body:{}}}"#, e.to_string()) + } + } + } + app.emit_all("streamed_fetch", &format!(r#"{{"type": "end", "id": "{}"}}"#, id)).unwrap(); + return "{\"success\":true}".to_string(); + } + Err(e) => { + return format!(r#"{{"success":false, body:{}}}"#, e.to_string()) + } + } +} + fn main() { tauri::Builder::default() .invoke_handler(tauri::generate_handler![ @@ -386,7 +460,8 @@ fn main() { install_pip, post_py_install, run_py_server, - install_py_dependencies + install_py_dependencies, + streamed_fetch ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/src/ts/process/request.ts b/src/ts/process/request.ts index 687a17dc..d5584bea 100644 --- a/src/ts/process/request.ts +++ b/src/ts/process/request.ts @@ -4,7 +4,7 @@ import { DataBase, setDatabase, type character } from "../storage/database"; import { pluginProcess } from "../plugins/plugins"; import { language } from "../../lang"; import { stringlizeAINChat, stringlizeChat, stringlizeChatOba, getStopStrings, unstringlizeAIN, unstringlizeChat } from "./stringlize"; -import { addFetchLog, globalFetch, isNodeServer, isTauri } from "../storage/globalApi"; +import { addFetchLog, fetchNative, globalFetch, isNodeServer, isTauri, textifyReadableStream } from "../storage/globalApi"; import { sleep } from "../util"; import { createDeep } from "./deepai"; import { hubURL } from "../characterCards"; @@ -526,36 +526,24 @@ export async function requestChatDataMain(arg:requestDataArgument, model:'model' } } } - const da = (throughProxi) - ? await fetch(hubURL + `/proxy2`, { - body: JSON.stringify(body), - headers: { - "risu-header": encodeURIComponent(JSON.stringify(headers)), - "risu-url": encodeURIComponent(replacerURL), - "Content-Type": "application/json", - "x-risu-tk": "use" - }, - method: "POST", - signal: abortSignal - }) - : await fetch(replacerURL, { - body: JSON.stringify(body), - method: "POST", - headers: headers, - signal: abortSignal - }) + const da = await fetchNative(replacerURL, { + body: JSON.stringify(body), + method: "POST", + headers: headers, + signal: abortSignal + }) if(da.status !== 200){ return { type: "fail", - result: await da.text() + result: await textifyReadableStream(da.body) } } if (!da.headers.get('Content-Type').includes('text/event-stream')){ return { type: "fail", - result: await da.text() + result: await textifyReadableStream(da.body) } } diff --git a/src/ts/storage/globalApi.ts b/src/ts/storage/globalApi.ts index 740f8f33..99a33491 100644 --- a/src/ts/storage/globalApi.ts +++ b/src/ts/storage/globalApi.ts @@ -27,6 +27,7 @@ import { Capacitor, CapacitorHttp } from '@capacitor/core'; import * as CapFS from '@capacitor/filesystem' import { save } from "@tauri-apps/api/dialog"; import type { RisuModule } from "../process/modules"; +import { listen } from '@tauri-apps/api/event' //@ts-ignore export const isTauri = !!window.__TAURI__ @@ -1277,4 +1278,152 @@ export class LocalWriter{ async close(){ await this.writer.close() } +} + +let fetchIndex = 0 +let tauriNativeFetchData:{[key:string]:StreamedFetchChunk[]} = {} + +interface StreamedFetchChunkData{ + type:'chunk', + body:string, + id:string +} + +interface StreamedFetchHeaderData{ + type:'headers', + body:{[key:string]:string}, + id:string, + status:number +} + +interface StreamedFetchEndData{ + type:'end', + id:string +} + +type StreamedFetchChunk = StreamedFetchChunkData|StreamedFetchHeaderData|StreamedFetchEndData + +listen('streamed_fetch', (event) => { + try { + const parsed = JSON.parse(event.payload as string) + const id = parsed.id + tauriNativeFetchData[id]?.push(parsed) + } catch (error) { + console.error(error) + } +}) + +export async function fetchNative(url:string, arg:{ + body:string, + headers?:{[key:string]:string}, + method?:"POST", + signal?:AbortSignal, + useRisuTk?:boolean +}):Promise<{ body: ReadableStream; headers: Headers; status: number }> { + let headers = arg.headers ?? {} + const db = get(DataBase) + let throughProxi = (!isTauri) && (!isNodeServer) && (!db.usePlainFetch) && (!Capacitor.isNativePlatform()) + if(isTauri){ + fetchIndex++ + if(arg.signal && arg.signal.aborted){ + throw new Error('aborted') + } + if(fetchIndex >= 100000){ + fetchIndex = 0 + } + let fetchId = fetchIndex.toString().padStart(5,'0') + tauriNativeFetchData[fetchId] = [] + let resolved = false + + let error = '' + invoke('streamed_fetch', { + id: fetchId, + url: url, + headers: JSON.stringify(headers), + body: arg.body, + }).then((res) => { + const parsedRes = JSON.parse(res as string) + if(!parsedRes.success){ + error = parsedRes.body + resolved = true + } + }) + + let resHeaders:{[key:string]:string} = null + let status = 400 + + const readableStream = new ReadableStream({ + async start(controller) { + while(!resolved || tauriNativeFetchData[fetchId].length > 0){ + if(tauriNativeFetchData[fetchId].length > 0){ + const data = tauriNativeFetchData[fetchId].shift() + console.log(data) + if(data.type === 'chunk'){ + const chunk = Buffer.from(data.body, 'base64') + controller.enqueue(chunk) + } + if(data.type === 'headers'){ + resHeaders = data.body + status = data.status + } + if(data.type === 'end'){ + resolved = true + } + } + await sleep(10) + } + controller.close() + } + }) + + while(resHeaders === null && !resolved){ + await sleep(10) + } + + if(resHeaders === null){ + resHeaders = {} + } + + if(error !== ''){ + throw new Error(error) + } + + + return { + body: readableStream, + headers: new Headers(resHeaders), + status: status + } + + + } + else if(throughProxi){ + return await fetch(hubURL + `/proxy2`, { + body: arg.body, + headers: arg.useRisuTk ? { + "risu-header": encodeURIComponent(JSON.stringify(headers)), + "risu-url": encodeURIComponent(url), + "Content-Type": "application/json", + "x-risu-tk": "use" + }: { + "risu-header": encodeURIComponent(JSON.stringify(headers)), + "risu-url": encodeURIComponent(url), + "Content-Type": "application/json" + }, + method: "POST", + signal: arg.signal + }) + } + else{ + return await fetch(url, { + body: arg.body, + headers: headers, + method: arg.method, + signal: arg.signal + }) + } +} + +export function textifyReadableStream(stream:ReadableStream){ + return new Response(stream).text() } \ No newline at end of file