use crate::auth::{find_key, verify_signature}; use crate::keys::parse_authorized_keys; use crate::persist; use crate::state::{AppState, AttachSink}; use axum::extract::ws::{Message, WebSocket}; use axum::extract::{State, WebSocketUpgrade}; use axum::response::IntoResponse; use futures_util::{SinkExt, StreamExt}; use rand::RngCore; use rsh_types::{ AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord, SessionView, }; use ssh_key::PublicKey; use std::sync::Arc; use tokio::sync::mpsc; pub async fn handler( ws: WebSocketUpgrade, State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(move |socket| run(socket, state)) } async fn run(mut socket: WebSocket, state: Arc) { let pubkey = match auth_handshake(&mut socket, &state).await { Ok(k) => k, Err(reason) => { let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await; return; } }; if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() { return; } tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed"); let (out_tx, mut out_rx) = mpsc::channel::(128); let (mut sink, mut stream) = socket.split(); let writer = tokio::spawn(async move { while let Some(msg) = out_rx.recv().await { let text = match serde_json::to_string(&msg) { Ok(t) => t, Err(_) => break, }; if sink.send(Message::Text(text)).await.is_err() { break; } } let _ = sink.close().await; }); let mut attached: Option<(String, u64, u64)> = None; while let Some(Ok(msg)) = stream.next().await { let text = match msg { Message::Text(t) => t, Message::Binary(b) => match String::from_utf8(b) { Ok(s) => s, Err(_) => continue, }, Message::Close(_) => break, _ => continue, }; let op: OpMsg = match serde_json::from_str(&text) { Ok(v) => v, Err(e) => { tracing::warn!("bad op msg: {e}"); continue; } }; let (req_id, body) = match op { OpMsg::Req { id, body } => (id, body), _ => continue, }; let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await; if let Some(r) = resp { if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() { break; } } } if let Some((s, c, _)) = attached { if let Some(handle) = state.connections.get(&(s, c)) { let mut a = handle.attach.lock().await; *a = None; } } drop(out_tx); let _ = writer.await; } async fn auth_handshake(socket: &mut WebSocket, state: &Arc) -> Result { let init = recv_op(socket).await.ok_or_else(|| "no message".to_string())?; let offered_str = match init { OpMsg::AuthInit { pubkey_openssh } => pubkey_openssh, _ => return Err("expected AuthInit".into()), }; let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?; let keys = state.authorized_keys.read().await; let matched = find_key(&keys, &offered).cloned(); drop(keys); let matched = matched.ok_or_else(|| "key not authorized".to_string())?; let mut nonce = [0u8; 32]; rand::thread_rng().fill_bytes(&mut nonce); send(socket, &BackendOpMsg::Challenge { nonce }).await.map_err(|e| e.to_string())?; let signed = recv_op(socket).await.ok_or_else(|| "no signature".to_string())?; let (sig_bytes, _alg) = match signed { OpMsg::AuthSign { signature, alg } => (signature, alg), _ => return Err("expected AuthSign".into()), }; verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?; Ok(matched) } async fn handle_req( state: &Arc, out_tx: &mpsc::Sender, attached: &mut Option<(String, u64, u64)>, req_id: u64, body: OpReq, ) -> Option { match body { OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)), OpReq::SessionCreate { name, password_hash } => { let mut s = state.sessions.write().await; if s.contains_key(&name) { return Some(OpResp::Err(format!("session '{name}' already exists"))); } let rec = SessionRecord { id: name.clone(), password_hash, created_at: now_unix(), }; s.insert(name.clone(), rec.clone()); let snapshot: Vec<_> = s.values().cloned().collect(); drop(s); if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await { return Some(OpResp::Err(format!("persist: {e}"))); } let view = SessionView { id: rec.id.clone(), has_password: rec.password_hash.is_some(), created_at: rec.created_at, connection_count: 0, }; let _ = state.event_bus.send(OpEvent::NewSession(view)); Some(OpResp::Ok) } OpReq::SessionDelete { name, disconnect } => { let mut s = state.sessions.write().await; if s.remove(&name).is_none() { return Some(OpResp::Err(format!("no such session '{name}'"))); } let snapshot: Vec<_> = s.values().cloned().collect(); drop(s); if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await { return Some(OpResp::Err(format!("persist: {e}"))); } if disconnect { disconnect_session(state, &name).await; } let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name }); Some(OpResp::Ok) } OpReq::SessionUpdate { name, set_password_hash, disconnect } => { let mut s = state.sessions.write().await; let Some(rec) = s.get_mut(&name) else { return Some(OpResp::Err(format!("no such session '{name}'"))); }; if let Some(pw) = set_password_hash { rec.password_hash = pw; } let snapshot: Vec<_> = s.values().cloned().collect(); drop(s); if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await { return Some(OpResp::Err(format!("persist: {e}"))); } if disconnect { disconnect_session(state, &name).await; } Some(OpResp::Ok) } OpReq::ConnectionList { session } => { Some(OpResp::Connections(state.list_connections(session.as_deref()))) } OpReq::Attach { session, connection_id, pty: _, cols, rows } => { let conn_id = match connection_id { Some(c) => c, None => { let mut found = None; for kv in state.connections.iter() { if kv.key().0 == session { if found.is_some() { return Some(OpResp::Err("multiple connections; specify id".into())); } found = Some(kv.key().1); } } match found { Some(c) => c, None => return Some(OpResp::Err("no connections".into())), } } }; let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else { return Some(OpResp::Err("connection not found".into())); }; { let mut a = handle.attach.lock().await; *a = Some(AttachSink { req_id, sender: out_tx.clone() }); } let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await; *attached = Some((session, conn_id, req_id)); Some(OpResp::AttachReady { connection_id: conn_id }) } OpReq::AttachIO(frame) => { let Some((session, conn_id, _)) = attached.clone() else { return Some(OpResp::Err("not attached".into())); }; let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else { return Some(OpResp::Err("connection gone".into())); }; match frame { AttachIOFrame::Stdin(b) => { let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await; } AttachIOFrame::Resize { cols, rows } => { let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await; } AttachIOFrame::Kill => { let _ = handle.to_stub.send(BackendStubMsg::Kill).await; } AttachIOFrame::Eof => { let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await; } } None } OpReq::Detach => { if let Some((s, c, _)) = attached.take() { if let Some(handle) = state.connections.get(&(s, c)) { let mut a = handle.attach.lock().await; *a = None; } } Some(OpResp::Ok) } OpReq::KeysList => { let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path()) .await .unwrap_or_default(); let keys: Vec = text .lines() .map(|l| l.trim().to_string()) .filter(|l| !l.is_empty()) .collect(); Some(OpResp::Keys(keys)) } OpReq::KeysAppend { keys } => { let path = state.cfg.authorized_keys_path(); let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default(); for k in keys { let k = k.trim(); if k.is_empty() { continue; } if !text.lines().any(|l| l.trim() == k) { if !text.is_empty() && !text.ends_with('\n') { text.push('\n'); } text.push_str(k); text.push('\n'); } } if let Err(e) = persist::save_authorized_keys_text(&path, &text).await { return Some(OpResp::Err(format!("persist: {e}"))); } reload_authorized_keys(state, &text).await; Some(OpResp::Ok) } OpReq::KeysRemove { keys } => { let path = state.cfg.authorized_keys_path(); let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default(); let targets: Vec = keys.iter().map(|k| k.trim().to_string()).collect(); let new: String = text .lines() .filter(|l| { let lt = l.trim(); !targets.iter().any(|t| t == lt) }) .collect::>() .join("\n"); let new = if new.is_empty() { new } else { format!("{}\n", new) }; if let Err(e) = persist::save_authorized_keys_text(&path, &new).await { return Some(OpResp::Err(format!("persist: {e}"))); } reload_authorized_keys(state, &new).await; Some(OpResp::Ok) } OpReq::KeysReplace { content } => { let path = state.cfg.authorized_keys_path(); if let Err(e) = persist::save_authorized_keys_text(&path, &content).await { return Some(OpResp::Err(format!("persist: {e}"))); } reload_authorized_keys(state, &content).await; Some(OpResp::Ok) } OpReq::Watch { session } => { let mut rx = state.event_bus.subscribe(); let tx = out_tx.clone(); tokio::spawn(async move { while let Ok(ev) = rx.recv().await { let pass = match &ev { OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s), OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s), OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s), OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s), }; if !pass { continue; } if tx.send(BackendOpMsg::Event(ev)).await.is_err() { break; } } }); Some(OpResp::WatchStarted) } } } async fn list_sessions(state: &Arc) -> Vec { let s = state.sessions.read().await; s.values() .map(|r| SessionView { id: r.id.clone(), has_password: r.password_hash.is_some(), created_at: r.created_at, connection_count: state.connection_count(&r.id), }) .collect() } async fn disconnect_session(state: &Arc, name: &str) { let mut to_kill = Vec::new(); for kv in state.connections.iter() { if kv.key().0 == name { to_kill.push((kv.key().clone(), kv.value().clone())); } } for (_, handle) in &to_kill { let _ = handle.to_stub.send(BackendStubMsg::Kill).await; } } async fn reload_authorized_keys(state: &Arc, text: &str) { let parsed = parse_authorized_keys(text); let mut k = state.authorized_keys.write().await; *k = parsed; } async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> { let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?; socket.send(Message::Text(t)).await } async fn recv_op(socket: &mut WebSocket) -> Option { loop { match socket.recv().await? { Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(), Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(), Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue, Ok(Message::Close(_)) => return None, Err(_) => return None, } } } fn now_unix() -> i64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_secs() as i64) .unwrap_or(0) }