initial commit

This commit is contained in:
2026-05-12 21:38:14 +09:00
commit bab9ac8733
42 changed files with 6419 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
[package]
name = "rsh-backend"
version = "0.1.0"
edition.workspace = true
[dependencies]
rsh-types = { workspace = true }
tokio = { workspace = true }
axum = { workspace = true }
tower-http = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
argon2 = { workspace = true }
ssh-key = { workspace = true }
signature = { workspace = true }
dashmap = { workspace = true }
rand = { workspace = true }
bytes = { workspace = true }
futures-util = { workspace = true }
time = { workspace = true }

View File

@@ -0,0 +1,29 @@
use anyhow::{anyhow, Result};
use argon2::password_hash::{PasswordHash, PasswordVerifier};
use argon2::Argon2;
use ssh_key::public::PublicKey;
use ssh_key::SshSig;
pub fn verify_password(password: &str, hash: &str) -> bool {
let parsed = match PasswordHash::new(hash) {
Ok(p) => p,
Err(_) => return false,
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed)
.is_ok()
}
pub fn verify_signature(pubkey: &PublicKey, nonce: &[u8], signature_blob: &[u8]) -> Result<()> {
let sig = SshSig::from_pem(signature_blob)
.map_err(|e| anyhow!("parse signature: {e}"))?;
pubkey
.verify("rsh-auth", nonce, &sig)
.map_err(|e| anyhow!("verify failed: {e}"))?;
Ok(())
}
pub fn find_key<'a>(keys: &'a [PublicKey], offered: &PublicKey) -> Option<&'a PublicKey> {
let fp = offered.fingerprint(Default::default());
keys.iter().find(|k| k.fingerprint(Default::default()) == fp)
}

View File

@@ -0,0 +1,30 @@
use std::net::SocketAddr;
use std::path::PathBuf;
#[derive(Clone, Debug)]
pub struct Config {
pub data_dir: PathBuf,
pub bind: SocketAddr,
pub log: String,
}
impl Config {
pub fn from_env() -> anyhow::Result<Self> {
let data_dir = std::env::var("RSH_DATA")
.unwrap_or_else(|_| "/var/lib/rsh".to_string())
.into();
let bind: SocketAddr = std::env::var("RSH_BIND")
.unwrap_or_else(|_| "0.0.0.0:7777".to_string())
.parse()?;
let log = std::env::var("RSH_LOG").unwrap_or_else(|_| "info,tower_http=warn".to_string());
Ok(Self { data_dir, bind, log })
}
pub fn sessions_path(&self) -> PathBuf {
self.data_dir.join("sessions.json")
}
pub fn authorized_keys_path(&self) -> PathBuf {
self.data_dir.join("authorized_keys")
}
}

View File

@@ -0,0 +1,10 @@
use ssh_key::PublicKey;
pub fn parse_authorized_keys(content: &str) -> Vec<PublicKey> {
content
.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.filter_map(|l| PublicKey::from_openssh(l).ok())
.collect()
}

View File

@@ -0,0 +1,60 @@
mod auth;
mod config;
mod keys;
mod persist;
mod state;
mod ws_op;
mod ws_stub;
use anyhow::Context;
use axum::routing::get;
use axum::Router;
use config::Config;
use state::AppState;
use std::collections::HashMap;
use std::sync::Arc;
use tracing_subscriber::EnvFilter;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cfg = Config::from_env()?;
let filter = EnvFilter::try_new(&cfg.log).unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::fmt().with_env_filter(filter).init();
tokio::fs::create_dir_all(&cfg.data_dir).await.ok();
let state = Arc::new(AppState::new(cfg.clone()));
let sessions = persist::load_sessions(&cfg.sessions_path())
.await
.context("load sessions")?;
{
let mut map = state.sessions.write().await;
let mut h: HashMap<String, _> = HashMap::new();
for s in sessions {
h.insert(s.id.clone(), s);
}
*map = h;
}
let keys_text = persist::load_authorized_keys_text(&cfg.authorized_keys_path())
.await
.unwrap_or_default();
{
let mut k = state.authorized_keys.write().await;
*k = keys::parse_authorized_keys(&keys_text);
tracing::info!(count = k.len(), "loaded authorized keys");
}
let app = Router::new()
.route("/healthz", get(|| async { "ok" }))
.route("/ws/stub", get(ws_stub::handler))
.route("/ws/op", get(ws_op::handler))
.with_state(state.clone())
.layer(tower_http::trace::TraceLayer::new_for_http());
tracing::info!(bind = %cfg.bind, data = ?cfg.data_dir, "rsh-backend listening");
let listener = tokio::net::TcpListener::bind(cfg.bind).await?;
axum::serve(listener, app).await?;
Ok(())
}

View File

@@ -0,0 +1,45 @@
use anyhow::Context;
use rsh_types::SessionRecord;
use std::path::Path;
use tokio::fs;
use tokio::io::AsyncWriteExt;
pub async fn load_sessions(path: &Path) -> anyhow::Result<Vec<SessionRecord>> {
if !path.exists() {
return Ok(Vec::new());
}
let data = fs::read(path).await.with_context(|| format!("read {:?}", path))?;
let v: Vec<SessionRecord> = serde_json::from_slice(&data)?;
Ok(v)
}
pub async fn save_sessions(path: &Path, sessions: &[SessionRecord]) -> anyhow::Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await.ok();
}
let tmp = path.with_extension("json.tmp");
let bytes = serde_json::to_vec_pretty(sessions)?;
let mut f = fs::File::create(&tmp).await?;
f.write_all(&bytes).await?;
f.sync_all().await?;
drop(f);
fs::rename(&tmp, path).await?;
Ok(())
}
pub async fn load_authorized_keys_text(path: &Path) -> anyhow::Result<String> {
if !path.exists() {
return Ok(String::new());
}
Ok(fs::read_to_string(path).await?)
}
pub async fn save_authorized_keys_text(path: &Path, content: &str) -> anyhow::Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await.ok();
}
let tmp = path.with_extension("tmp");
fs::write(&tmp, content.as_bytes()).await?;
fs::rename(&tmp, path).await?;
Ok(())
}

View File

@@ -0,0 +1,71 @@
use crate::config::Config;
use dashmap::DashMap;
use rsh_types::{BackendOpMsg, BackendStubMsg, OpEvent, SessionRecord, StubInfo};
use ssh_key::PublicKey;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
pub struct ConnHandle {
pub info: StubInfo,
pub to_stub: mpsc::Sender<BackendStubMsg>,
pub attach: Mutex<Option<AttachSink>>,
pub connected_at: i64,
}
pub struct AttachSink {
pub req_id: u64,
pub sender: mpsc::Sender<BackendOpMsg>,
}
pub struct AppState {
pub cfg: Config,
pub sessions: RwLock<HashMap<String, SessionRecord>>,
pub connections: DashMap<(String, u64), Arc<ConnHandle>>,
pub next_conn_id: DashMap<String, AtomicU64>,
pub authorized_keys: RwLock<Vec<PublicKey>>,
pub event_bus: broadcast::Sender<OpEvent>,
}
impl AppState {
pub fn new(cfg: Config) -> Self {
let (tx, _) = broadcast::channel(256);
Self {
cfg,
sessions: RwLock::new(HashMap::new()),
connections: DashMap::new(),
next_conn_id: DashMap::new(),
authorized_keys: RwLock::new(Vec::new()),
event_bus: tx,
}
}
pub fn alloc_conn_id(&self, session: &str) -> u64 {
let entry = self
.next_conn_id
.entry(session.to_string())
.or_insert_with(|| AtomicU64::new(1));
entry.fetch_add(1, Ordering::Relaxed)
}
pub fn connection_count(&self, session: &str) -> u32 {
self.connections
.iter()
.filter(|kv| kv.key().0 == session)
.count() as u32
}
pub fn list_connections(&self, filter: Option<&str>) -> Vec<rsh_types::ConnectionView> {
self.connections
.iter()
.filter(|kv| filter.map_or(true, |s| kv.key().0 == s))
.map(|kv| rsh_types::ConnectionView {
session_id: kv.key().0.clone(),
connection_id: kv.key().1,
info: kv.value().info.clone(),
connected_at: kv.value().connected_at,
})
.collect()
}
}

View File

@@ -0,0 +1,388 @@
use crate::auth::{find_key, verify_signature};
use crate::keys::parse_authorized_keys;
use crate::persist;
use crate::state::{AppState, AttachSink};
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use futures_util::{SinkExt, StreamExt};
use rand::RngCore;
use rsh_types::{
AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord,
SessionView,
};
use ssh_key::PublicKey;
use std::sync::Arc;
use tokio::sync::mpsc;
pub async fn handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| run(socket, state))
}
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
let pubkey = match auth_handshake(&mut socket, &state).await {
Ok(k) => k,
Err(reason) => {
let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await;
return;
}
};
if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() {
return;
}
tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed");
let (out_tx, mut out_rx) = mpsc::channel::<BackendOpMsg>(128);
let (mut sink, mut stream) = socket.split();
let writer = tokio::spawn(async move {
while let Some(msg) = out_rx.recv().await {
let text = match serde_json::to_string(&msg) {
Ok(t) => t,
Err(_) => break,
};
if sink.send(Message::Text(text)).await.is_err() {
break;
}
}
let _ = sink.close().await;
});
let mut attached: Option<(String, u64, u64)> = None;
while let Some(Ok(msg)) = stream.next().await {
let text = match msg {
Message::Text(t) => t,
Message::Binary(b) => match String::from_utf8(b) {
Ok(s) => s,
Err(_) => continue,
},
Message::Close(_) => break,
_ => continue,
};
let op: OpMsg = match serde_json::from_str(&text) {
Ok(v) => v,
Err(e) => {
tracing::warn!("bad op msg: {e}");
continue;
}
};
let (req_id, body) = match op {
OpMsg::Req { id, body } => (id, body),
_ => continue,
};
let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await;
if let Some(r) = resp {
if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() {
break;
}
}
}
if let Some((s, c, _)) = attached {
if let Some(handle) = state.connections.get(&(s, c)) {
let mut a = handle.attach.lock().await;
*a = None;
}
}
drop(out_tx);
let _ = writer.await;
}
async fn auth_handshake(socket: &mut WebSocket, state: &Arc<AppState>) -> Result<PublicKey, String> {
let init = recv_op(socket).await.ok_or_else(|| "no message".to_string())?;
let offered_str = match init {
OpMsg::AuthInit { pubkey_openssh } => pubkey_openssh,
_ => return Err("expected AuthInit".into()),
};
let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?;
let keys = state.authorized_keys.read().await;
let matched = find_key(&keys, &offered).cloned();
drop(keys);
let matched = matched.ok_or_else(|| "key not authorized".to_string())?;
let mut nonce = [0u8; 32];
rand::thread_rng().fill_bytes(&mut nonce);
send(socket, &BackendOpMsg::Challenge { nonce }).await.map_err(|e| e.to_string())?;
let signed = recv_op(socket).await.ok_or_else(|| "no signature".to_string())?;
let (sig_bytes, _alg) = match signed {
OpMsg::AuthSign { signature, alg } => (signature, alg),
_ => return Err("expected AuthSign".into()),
};
verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?;
Ok(matched)
}
async fn handle_req(
state: &Arc<AppState>,
out_tx: &mpsc::Sender<BackendOpMsg>,
attached: &mut Option<(String, u64, u64)>,
req_id: u64,
body: OpReq,
) -> Option<OpResp> {
match body {
OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)),
OpReq::SessionCreate { name, password_hash } => {
let mut s = state.sessions.write().await;
if s.contains_key(&name) {
return Some(OpResp::Err(format!("session '{name}' already exists")));
}
let rec = SessionRecord {
id: name.clone(),
password_hash,
created_at: now_unix(),
};
s.insert(name.clone(), rec.clone());
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
let view = SessionView {
id: rec.id.clone(),
has_password: rec.password_hash.is_some(),
created_at: rec.created_at,
connection_count: 0,
};
let _ = state.event_bus.send(OpEvent::NewSession(view));
Some(OpResp::Ok)
}
OpReq::SessionDelete { name, disconnect } => {
let mut s = state.sessions.write().await;
if s.remove(&name).is_none() {
return Some(OpResp::Err(format!("no such session '{name}'")));
}
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
if disconnect {
disconnect_session(state, &name).await;
}
let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name });
Some(OpResp::Ok)
}
OpReq::SessionUpdate { name, set_password_hash, disconnect } => {
let mut s = state.sessions.write().await;
let Some(rec) = s.get_mut(&name) else {
return Some(OpResp::Err(format!("no such session '{name}'")));
};
if let Some(pw) = set_password_hash {
rec.password_hash = pw;
}
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
if disconnect {
disconnect_session(state, &name).await;
}
Some(OpResp::Ok)
}
OpReq::ConnectionList { session } => {
Some(OpResp::Connections(state.list_connections(session.as_deref())))
}
OpReq::Attach { session, connection_id, pty: _, cols, rows } => {
let conn_id = match connection_id {
Some(c) => c,
None => {
let mut found = None;
for kv in state.connections.iter() {
if kv.key().0 == session {
if found.is_some() {
return Some(OpResp::Err("multiple connections; specify id".into()));
}
found = Some(kv.key().1);
}
}
match found {
Some(c) => c,
None => return Some(OpResp::Err("no connections".into())),
}
}
};
let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else {
return Some(OpResp::Err("connection not found".into()));
};
{
let mut a = handle.attach.lock().await;
*a = Some(AttachSink { req_id, sender: out_tx.clone() });
}
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
*attached = Some((session, conn_id, req_id));
Some(OpResp::AttachReady { connection_id: conn_id })
}
OpReq::AttachIO(frame) => {
let Some((session, conn_id, _)) = attached.clone() else {
return Some(OpResp::Err("not attached".into()));
};
let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else {
return Some(OpResp::Err("connection gone".into()));
};
match frame {
AttachIOFrame::Stdin(b) => {
let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await;
}
AttachIOFrame::Resize { cols, rows } => {
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
}
AttachIOFrame::Kill => {
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
}
AttachIOFrame::Eof => {
let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await;
}
}
None
}
OpReq::Detach => {
if let Some((s, c, _)) = attached.take() {
if let Some(handle) = state.connections.get(&(s, c)) {
let mut a = handle.attach.lock().await;
*a = None;
}
}
Some(OpResp::Ok)
}
OpReq::KeysList => {
let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path())
.await
.unwrap_or_default();
let keys: Vec<String> = text
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
Some(OpResp::Keys(keys))
}
OpReq::KeysAppend { keys } => {
let path = state.cfg.authorized_keys_path();
let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
for k in keys {
let k = k.trim();
if k.is_empty() {
continue;
}
if !text.lines().any(|l| l.trim() == k) {
if !text.is_empty() && !text.ends_with('\n') {
text.push('\n');
}
text.push_str(k);
text.push('\n');
}
}
if let Err(e) = persist::save_authorized_keys_text(&path, &text).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &text).await;
Some(OpResp::Ok)
}
OpReq::KeysRemove { keys } => {
let path = state.cfg.authorized_keys_path();
let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
let targets: Vec<String> = keys.iter().map(|k| k.trim().to_string()).collect();
let new: String = text
.lines()
.filter(|l| {
let lt = l.trim();
!targets.iter().any(|t| t == lt)
})
.collect::<Vec<_>>()
.join("\n");
let new = if new.is_empty() { new } else { format!("{}\n", new) };
if let Err(e) = persist::save_authorized_keys_text(&path, &new).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &new).await;
Some(OpResp::Ok)
}
OpReq::KeysReplace { content } => {
let path = state.cfg.authorized_keys_path();
if let Err(e) = persist::save_authorized_keys_text(&path, &content).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &content).await;
Some(OpResp::Ok)
}
OpReq::Watch { session } => {
let mut rx = state.event_bus.subscribe();
let tx = out_tx.clone();
tokio::spawn(async move {
while let Ok(ev) = rx.recv().await {
let pass = match &ev {
OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s),
OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s),
OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s),
OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s),
};
if !pass {
continue;
}
if tx.send(BackendOpMsg::Event(ev)).await.is_err() {
break;
}
}
});
Some(OpResp::WatchStarted)
}
}
}
async fn list_sessions(state: &Arc<AppState>) -> Vec<SessionView> {
let s = state.sessions.read().await;
s.values()
.map(|r| SessionView {
id: r.id.clone(),
has_password: r.password_hash.is_some(),
created_at: r.created_at,
connection_count: state.connection_count(&r.id),
})
.collect()
}
async fn disconnect_session(state: &Arc<AppState>, name: &str) {
let mut to_kill = Vec::new();
for kv in state.connections.iter() {
if kv.key().0 == name {
to_kill.push((kv.key().clone(), kv.value().clone()));
}
}
for (_, handle) in &to_kill {
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
}
}
async fn reload_authorized_keys(state: &Arc<AppState>, text: &str) {
let parsed = parse_authorized_keys(text);
let mut k = state.authorized_keys.write().await;
*k = parsed;
}
async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> {
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
socket.send(Message::Text(t)).await
}
async fn recv_op(socket: &mut WebSocket) -> Option<OpMsg> {
loop {
match socket.recv().await? {
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
Ok(Message::Close(_)) => return None,
Err(_) => return None,
}
}
}
fn now_unix() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0)
}

View File

@@ -0,0 +1,155 @@
use crate::auth::verify_password;
use crate::state::{AppState, ConnHandle};
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use futures_util::{SinkExt, StreamExt};
use rsh_types::{BackendOpMsg, BackendStubMsg, ConnectionView, OpEvent, OpResp, StubMsg};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
pub async fn handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| run(socket, state))
}
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
let hello = match recv_text::<StubMsg>(&mut socket).await {
Some(StubMsg::Hello { session_id, password, info }) => (session_id, password, info),
_ => {
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "expected Hello".into() }).await;
return;
}
};
let (session_id, password, info) = hello;
let sessions = state.sessions.read().await;
let session = match sessions.get(&session_id) {
Some(s) => s.clone(),
None => {
drop(sessions);
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "no such session".into() }).await;
return;
}
};
drop(sessions);
if let Some(hash) = &session.password_hash {
let provided = password.unwrap_or_default();
if !verify_password(&provided, hash) {
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "bad password".into() }).await;
return;
}
}
let conn_id = state.alloc_conn_id(&session_id);
let (to_stub_tx, mut to_stub_rx) = mpsc::channel::<BackendStubMsg>(64);
let connected_at = now_unix();
let handle = Arc::new(ConnHandle {
info: info.clone(),
to_stub: to_stub_tx.clone(),
attach: Mutex::new(None),
connected_at,
});
state.connections.insert((session_id.clone(), conn_id), handle.clone());
let _ = state.event_bus.send(OpEvent::NewConnection(ConnectionView {
session_id: session_id.clone(),
connection_id: conn_id,
info: info.clone(),
connected_at,
}));
if send(&mut socket, &BackendStubMsg::Accepted { connection_id: conn_id }).await.is_err() {
cleanup(&state, &session_id, conn_id).await;
return;
}
let (mut ws_sink, mut ws_stream) = socket.split();
let writer = tokio::spawn(async move {
while let Some(msg) = to_stub_rx.recv().await {
let text = match serde_json::to_string(&msg) {
Ok(t) => t,
Err(_) => break,
};
if ws_sink.send(Message::Text(text)).await.is_err() {
break;
}
}
let _ = ws_sink.close().await;
});
let state_r = state.clone();
let session_r = session_id.clone();
let handle_r = handle.clone();
let reader = tokio::spawn(async move {
while let Some(Ok(msg)) = ws_stream.next().await {
let text = match msg {
Message::Text(t) => t,
Message::Binary(b) => match String::from_utf8(b) {
Ok(s) => s,
Err(_) => continue,
},
Message::Close(_) => break,
_ => continue,
};
let parsed: StubMsg = match serde_json::from_str(&text) {
Ok(v) => v,
Err(_) => continue,
};
match parsed {
StubMsg::Stdout(b) => forward_op(&handle_r, OpResp::Stdout(b)).await,
StubMsg::Stderr(b) => forward_op(&handle_r, OpResp::Stderr(b)).await,
StubMsg::Exited { code } => {
forward_op(&handle_r, OpResp::Exited { code }).await;
break;
}
StubMsg::Pong => {}
StubMsg::Hello { .. } => {}
}
}
cleanup(&state_r, &session_r, conn_id).await;
});
let _ = tokio::join!(writer, reader);
}
async fn forward_op(handle: &Arc<crate::state::ConnHandle>, resp: OpResp) {
let attach = handle.attach.lock().await;
if let Some(sink) = attach.as_ref() {
let _ = sink
.sender
.send(BackendOpMsg::Resp { id: sink.req_id, body: resp })
.await;
}
}
async fn cleanup(state: &Arc<AppState>, session_id: &str, conn_id: u64) {
state.connections.remove(&(session_id.to_string(), conn_id));
let _ = state.event_bus.send(OpEvent::ConnectionClosed {
session: session_id.to_string(),
connection_id: conn_id,
});
}
async fn send(socket: &mut WebSocket, msg: &BackendStubMsg) -> Result<(), axum::Error> {
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
socket.send(Message::Text(t)).await
}
async fn recv_text<T: serde::de::DeserializeOwned>(socket: &mut WebSocket) -> Option<T> {
loop {
match socket.recv().await? {
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
Ok(Message::Close(_)) => return None,
Err(_) => return None,
}
}
}
fn now_unix() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0)
}