Files
rsh/crates/rsh-backend/src/ws_op.rs
2026-05-12 21:38:14 +09:00

389 lines
15 KiB
Rust

use crate::auth::{find_key, verify_signature};
use crate::keys::parse_authorized_keys;
use crate::persist;
use crate::state::{AppState, AttachSink};
use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use futures_util::{SinkExt, StreamExt};
use rand::RngCore;
use rsh_types::{
AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord,
SessionView,
};
use ssh_key::PublicKey;
use std::sync::Arc;
use tokio::sync::mpsc;
pub async fn handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| run(socket, state))
}
/// Drives one operator WebSocket session from handshake to teardown.
///
/// 1. Runs `auth_handshake`; on failure an `AuthFail` carrying the reason is sent
///    and the connection is dropped.
/// 2. Splits the socket. A spawned writer task serialises every outgoing
///    `BackendOpMsg` from an mpsc channel as a JSON text frame. This task reads
///    incoming frames, keeps only `OpMsg::Req`, and dispatches them to `handle_req`.
/// 3. When the peer goes away, it releases this operator's attach sink (only if
///    the sink still routes to this session) and stops the writer.
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
    let pubkey = match auth_handshake(&mut socket, &state).await {
        Ok(k) => k,
        Err(reason) => {
            let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await;
            return;
        }
    };
    if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() {
        return;
    }
    tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed");

    let (out_tx, mut out_rx) = mpsc::channel::<BackendOpMsg>(128);
    let (mut sink, mut stream) = socket.split();
    let writer = tokio::spawn(async move {
        while let Some(msg) = out_rx.recv().await {
            let text = match serde_json::to_string(&msg) {
                Ok(t) => t,
                Err(_) => break,
            };
            if sink.send(Message::Text(text)).await.is_err() {
                break;
            }
        }
        let _ = sink.close().await;
    });

    // (session, connection_id, attach request id) of the current attach target.
    let mut attached: Option<(String, u64, u64)> = None;
    while let Some(Ok(msg)) = stream.next().await {
        let text = match msg {
            Message::Text(t) => t,
            Message::Binary(b) => match String::from_utf8(b) {
                Ok(s) => s,
                Err(_) => continue,
            },
            Message::Close(_) => break,
            _ => continue,
        };
        let op: OpMsg = match serde_json::from_str(&text) {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("bad op msg: {e}");
                continue;
            }
        };
        let (req_id, body) = match op {
            OpMsg::Req { id, body } => (id, body),
            _ => continue,
        };
        let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await;
        if let Some(r) = resp {
            if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() {
                // The writer has exited (socket write failed); nothing more can be sent.
                break;
            }
        }
    }

    if let Some((s, c, _)) = attached {
        // Clone the handle out of the map first, as the Attach path does, so no map
        // entry guard is held across the `.await` on the attach mutex.
        let handle = state.connections.get(&(s, c)).map(|h| h.clone());
        if let Some(handle) = handle {
            let mut a = handle.attach.lock().await;
            // Another operator may have attached to this connection since. Only
            // clear the sink if it still routes to this session's channel.
            if matches!(&*a, Some(cur) if cur.sender.same_channel(&out_tx)) {
                *a = None;
            }
        }
    }
    drop(out_tx);
    // `Watch` forwarder tasks and attach sinks hold clones of `out_tx`. The writer's
    // `recv()` therefore would not observe the channel closing until one of them
    // next sends, and awaiting it could block indefinitely. The peer is gone, so
    // stop the writer now. Dropping its receiver makes those clones fail on their
    // next send, and they exit.
    writer.abort();
    let _ = writer.await;
}
/// Performs the operator challenge/response authentication on a freshly upgraded socket.
///
/// The protocol runs in four steps:
/// 1. The client sends `AuthInit` with an OpenSSH-format public key.
/// 2. The key must match an entry in `state.authorized_keys` (via `find_key`).
/// 3. The backend sends a random 32-byte `Challenge` nonce.
/// 4. The client replies with `AuthSign`, whose signature over the nonce is checked
///    with `verify_signature` against the matched key. The `alg` field is not used.
///
/// # Errors
/// Returns a human-readable reason in any of these cases:
/// - a message is missing or of the wrong kind;
/// - the key does not parse or is not authorized;
/// - the challenge cannot be sent;
/// - the signature does not verify.
async fn auth_handshake(socket: &mut WebSocket, state: &Arc<AppState>) -> Result<PublicKey, String> {
    let offered_str = match recv_op(socket).await {
        None => return Err("no message".to_string()),
        Some(OpMsg::AuthInit { pubkey_openssh }) => pubkey_openssh,
        Some(_) => return Err("expected AuthInit".into()),
    };
    let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?;

    // Release the read guard before any further `.await`.
    let key_set = state.authorized_keys.read().await;
    let lookup = find_key(&key_set, &offered).cloned();
    drop(key_set);
    let matched = lookup.ok_or_else(|| "key not authorized".to_string())?;

    let nonce: [u8; 32] = {
        let mut buf = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut buf);
        buf
    };
    send(socket, &BackendOpMsg::Challenge { nonce })
        .await
        .map_err(|e| e.to_string())?;

    let sig_bytes = match recv_op(socket).await {
        None => return Err("no signature".to_string()),
        Some(OpMsg::AuthSign { signature, alg: _ }) => signature,
        Some(_) => return Err("expected AuthSign".into()),
    };
    verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?;
    Ok(matched)
}
async fn handle_req(
state: &Arc<AppState>,
out_tx: &mpsc::Sender<BackendOpMsg>,
attached: &mut Option<(String, u64, u64)>,
req_id: u64,
body: OpReq,
) -> Option<OpResp> {
match body {
OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)),
OpReq::SessionCreate { name, password_hash } => {
let mut s = state.sessions.write().await;
if s.contains_key(&name) {
return Some(OpResp::Err(format!("session '{name}' already exists")));
}
let rec = SessionRecord {
id: name.clone(),
password_hash,
created_at: now_unix(),
};
s.insert(name.clone(), rec.clone());
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
let view = SessionView {
id: rec.id.clone(),
has_password: rec.password_hash.is_some(),
created_at: rec.created_at,
connection_count: 0,
};
let _ = state.event_bus.send(OpEvent::NewSession(view));
Some(OpResp::Ok)
}
OpReq::SessionDelete { name, disconnect } => {
let mut s = state.sessions.write().await;
if s.remove(&name).is_none() {
return Some(OpResp::Err(format!("no such session '{name}'")));
}
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
if disconnect {
disconnect_session(state, &name).await;
}
let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name });
Some(OpResp::Ok)
}
OpReq::SessionUpdate { name, set_password_hash, disconnect } => {
let mut s = state.sessions.write().await;
let Some(rec) = s.get_mut(&name) else {
return Some(OpResp::Err(format!("no such session '{name}'")));
};
if let Some(pw) = set_password_hash {
rec.password_hash = pw;
}
let snapshot: Vec<_> = s.values().cloned().collect();
drop(s);
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
if disconnect {
disconnect_session(state, &name).await;
}
Some(OpResp::Ok)
}
OpReq::ConnectionList { session } => {
Some(OpResp::Connections(state.list_connections(session.as_deref())))
}
OpReq::Attach { session, connection_id, pty: _, cols, rows } => {
let conn_id = match connection_id {
Some(c) => c,
None => {
let mut found = None;
for kv in state.connections.iter() {
if kv.key().0 == session {
if found.is_some() {
return Some(OpResp::Err("multiple connections; specify id".into()));
}
found = Some(kv.key().1);
}
}
match found {
Some(c) => c,
None => return Some(OpResp::Err("no connections".into())),
}
}
};
let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else {
return Some(OpResp::Err("connection not found".into()));
};
{
let mut a = handle.attach.lock().await;
*a = Some(AttachSink { req_id, sender: out_tx.clone() });
}
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
*attached = Some((session, conn_id, req_id));
Some(OpResp::AttachReady { connection_id: conn_id })
}
OpReq::AttachIO(frame) => {
let Some((session, conn_id, _)) = attached.clone() else {
return Some(OpResp::Err("not attached".into()));
};
let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else {
return Some(OpResp::Err("connection gone".into()));
};
match frame {
AttachIOFrame::Stdin(b) => {
let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await;
}
AttachIOFrame::Resize { cols, rows } => {
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
}
AttachIOFrame::Kill => {
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
}
AttachIOFrame::Eof => {
let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await;
}
}
None
}
OpReq::Detach => {
if let Some((s, c, _)) = attached.take() {
if let Some(handle) = state.connections.get(&(s, c)) {
let mut a = handle.attach.lock().await;
*a = None;
}
}
Some(OpResp::Ok)
}
OpReq::KeysList => {
let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path())
.await
.unwrap_or_default();
let keys: Vec<String> = text
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
Some(OpResp::Keys(keys))
}
OpReq::KeysAppend { keys } => {
let path = state.cfg.authorized_keys_path();
let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
for k in keys {
let k = k.trim();
if k.is_empty() {
continue;
}
if !text.lines().any(|l| l.trim() == k) {
if !text.is_empty() && !text.ends_with('\n') {
text.push('\n');
}
text.push_str(k);
text.push('\n');
}
}
if let Err(e) = persist::save_authorized_keys_text(&path, &text).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &text).await;
Some(OpResp::Ok)
}
OpReq::KeysRemove { keys } => {
let path = state.cfg.authorized_keys_path();
let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
let targets: Vec<String> = keys.iter().map(|k| k.trim().to_string()).collect();
let new: String = text
.lines()
.filter(|l| {
let lt = l.trim();
!targets.iter().any(|t| t == lt)
})
.collect::<Vec<_>>()
.join("\n");
let new = if new.is_empty() { new } else { format!("{}\n", new) };
if let Err(e) = persist::save_authorized_keys_text(&path, &new).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &new).await;
Some(OpResp::Ok)
}
OpReq::KeysReplace { content } => {
let path = state.cfg.authorized_keys_path();
if let Err(e) = persist::save_authorized_keys_text(&path, &content).await {
return Some(OpResp::Err(format!("persist: {e}")));
}
reload_authorized_keys(state, &content).await;
Some(OpResp::Ok)
}
OpReq::Watch { session } => {
let mut rx = state.event_bus.subscribe();
let tx = out_tx.clone();
tokio::spawn(async move {
while let Ok(ev) = rx.recv().await {
let pass = match &ev {
OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s),
OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s),
OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s),
OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s),
};
if !pass {
continue;
}
if tx.send(BackendOpMsg::Event(ev)).await.is_err() {
break;
}
}
});
Some(OpResp::WatchStarted)
}
}
}
/// Builds a `SessionView` for every stored session.
///
/// Each view's `connection_count` comes from `state.connection_count`. The sessions
/// read lock is held while the views are assembled.
async fn list_sessions(state: &Arc<AppState>) -> Vec<SessionView> {
    let sessions = state.sessions.read().await;
    let mut views = Vec::new();
    for rec in sessions.values() {
        let live = state.connection_count(&rec.id);
        views.push(SessionView {
            id: rec.id.clone(),
            has_password: rec.password_hash.is_some(),
            created_at: rec.created_at,
            connection_count: live,
        });
    }
    views
}
/// Asks every live connection of session `name` to terminate by sending
/// `BackendStubMsg::Kill` to its stub.
///
/// Handles are cloned out of `state.connections` before any `.await`, so the map
/// is not borrowed while sending. Send failures are ignored.
async fn disconnect_session(state: &Arc<AppState>, name: &str) {
    let victims: Vec<_> = state
        .connections
        .iter()
        .filter(|entry| entry.key().0 == name)
        .map(|entry| entry.value().clone())
        .collect();
    for victim in &victims {
        let _ = victim.to_stub.send(BackendStubMsg::Kill).await;
    }
}
/// Re-parses `text` as authorized_keys content and swaps the result into
/// `state.authorized_keys`.
///
/// Parsing happens before the write lock is taken, so the lock is held only for
/// the assignment.
async fn reload_authorized_keys(state: &Arc<AppState>, text: &str) {
    let fresh = parse_authorized_keys(text);
    *state.authorized_keys.write().await = fresh;
}
async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> {
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
socket.send(Message::Text(t)).await
}
/// Reads the next `OpMsg` from `socket`; used during the auth handshake before the
/// socket is split.
///
/// Ping and Pong frames are skipped. Text and Binary frames are parsed as JSON.
/// Returns `None` in any of these cases:
/// - the stream ends;
/// - a Close frame arrives;
/// - the transport reports an error;
/// - the first data frame is not a valid `OpMsg`.
async fn recv_op(socket: &mut WebSocket) -> Option<OpMsg> {
    while let Some(frame) = socket.recv().await {
        let parsed = match frame.ok()? {
            Message::Text(t) => serde_json::from_str(&t),
            Message::Binary(b) => serde_json::from_slice(&b),
            Message::Ping(_) | Message::Pong(_) => continue,
            Message::Close(_) => return None,
        };
        return parsed.ok();
    }
    None
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn now_unix() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}