initial commit
This commit is contained in:
388
crates/rsh-backend/src/ws_op.rs
Normal file
388
crates/rsh-backend/src/ws_op.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
use crate::auth::{find_key, verify_signature};
|
||||
use crate::keys::parse_authorized_keys;
|
||||
use crate::persist;
|
||||
use crate::state::{AppState, AttachSink};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::extract::{State, WebSocketUpgrade};
|
||||
use axum::response::IntoResponse;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use rand::RngCore;
|
||||
use rsh_types::{
|
||||
AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord,
|
||||
SessionView,
|
||||
};
|
||||
use ssh_key::PublicKey;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| run(socket, state))
|
||||
}
|
||||
|
||||
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
|
||||
let pubkey = match auth_handshake(&mut socket, &state).await {
|
||||
Ok(k) => k,
|
||||
Err(reason) => {
|
||||
let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed");
|
||||
|
||||
let (out_tx, mut out_rx) = mpsc::channel::<BackendOpMsg>(128);
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
while let Some(msg) = out_rx.recv().await {
|
||||
let text = match serde_json::to_string(&msg) {
|
||||
Ok(t) => t,
|
||||
Err(_) => break,
|
||||
};
|
||||
if sink.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = sink.close().await;
|
||||
});
|
||||
|
||||
let mut attached: Option<(String, u64, u64)> = None;
|
||||
|
||||
while let Some(Ok(msg)) = stream.next().await {
|
||||
let text = match msg {
|
||||
Message::Text(t) => t,
|
||||
Message::Binary(b) => match String::from_utf8(b) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
},
|
||||
Message::Close(_) => break,
|
||||
_ => continue,
|
||||
};
|
||||
let op: OpMsg = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("bad op msg: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let (req_id, body) = match op {
|
||||
OpMsg::Req { id, body } => (id, body),
|
||||
_ => continue,
|
||||
};
|
||||
let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await;
|
||||
if let Some(r) = resp {
|
||||
if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((s, c, _)) = attached {
|
||||
if let Some(handle) = state.connections.get(&(s, c)) {
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = None;
|
||||
}
|
||||
}
|
||||
drop(out_tx);
|
||||
let _ = writer.await;
|
||||
}
|
||||
|
||||
async fn auth_handshake(socket: &mut WebSocket, state: &Arc<AppState>) -> Result<PublicKey, String> {
|
||||
let init = recv_op(socket).await.ok_or_else(|| "no message".to_string())?;
|
||||
let offered_str = match init {
|
||||
OpMsg::AuthInit { pubkey_openssh } => pubkey_openssh,
|
||||
_ => return Err("expected AuthInit".into()),
|
||||
};
|
||||
let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?;
|
||||
let keys = state.authorized_keys.read().await;
|
||||
let matched = find_key(&keys, &offered).cloned();
|
||||
drop(keys);
|
||||
let matched = matched.ok_or_else(|| "key not authorized".to_string())?;
|
||||
|
||||
let mut nonce = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut nonce);
|
||||
send(socket, &BackendOpMsg::Challenge { nonce }).await.map_err(|e| e.to_string())?;
|
||||
let signed = recv_op(socket).await.ok_or_else(|| "no signature".to_string())?;
|
||||
let (sig_bytes, _alg) = match signed {
|
||||
OpMsg::AuthSign { signature, alg } => (signature, alg),
|
||||
_ => return Err("expected AuthSign".into()),
|
||||
};
|
||||
verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?;
|
||||
Ok(matched)
|
||||
}
|
||||
|
||||
async fn handle_req(
|
||||
state: &Arc<AppState>,
|
||||
out_tx: &mpsc::Sender<BackendOpMsg>,
|
||||
attached: &mut Option<(String, u64, u64)>,
|
||||
req_id: u64,
|
||||
body: OpReq,
|
||||
) -> Option<OpResp> {
|
||||
match body {
|
||||
OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)),
|
||||
OpReq::SessionCreate { name, password_hash } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
if s.contains_key(&name) {
|
||||
return Some(OpResp::Err(format!("session '{name}' already exists")));
|
||||
}
|
||||
let rec = SessionRecord {
|
||||
id: name.clone(),
|
||||
password_hash,
|
||||
created_at: now_unix(),
|
||||
};
|
||||
s.insert(name.clone(), rec.clone());
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
let view = SessionView {
|
||||
id: rec.id.clone(),
|
||||
has_password: rec.password_hash.is_some(),
|
||||
created_at: rec.created_at,
|
||||
connection_count: 0,
|
||||
};
|
||||
let _ = state.event_bus.send(OpEvent::NewSession(view));
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::SessionDelete { name, disconnect } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
if s.remove(&name).is_none() {
|
||||
return Some(OpResp::Err(format!("no such session '{name}'")));
|
||||
}
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
if disconnect {
|
||||
disconnect_session(state, &name).await;
|
||||
}
|
||||
let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name });
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::SessionUpdate { name, set_password_hash, disconnect } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
let Some(rec) = s.get_mut(&name) else {
|
||||
return Some(OpResp::Err(format!("no such session '{name}'")));
|
||||
};
|
||||
if let Some(pw) = set_password_hash {
|
||||
rec.password_hash = pw;
|
||||
}
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
if disconnect {
|
||||
disconnect_session(state, &name).await;
|
||||
}
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::ConnectionList { session } => {
|
||||
Some(OpResp::Connections(state.list_connections(session.as_deref())))
|
||||
}
|
||||
OpReq::Attach { session, connection_id, pty: _, cols, rows } => {
|
||||
let conn_id = match connection_id {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
let mut found = None;
|
||||
for kv in state.connections.iter() {
|
||||
if kv.key().0 == session {
|
||||
if found.is_some() {
|
||||
return Some(OpResp::Err("multiple connections; specify id".into()));
|
||||
}
|
||||
found = Some(kv.key().1);
|
||||
}
|
||||
}
|
||||
match found {
|
||||
Some(c) => c,
|
||||
None => return Some(OpResp::Err("no connections".into())),
|
||||
}
|
||||
}
|
||||
};
|
||||
let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else {
|
||||
return Some(OpResp::Err("connection not found".into()));
|
||||
};
|
||||
{
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = Some(AttachSink { req_id, sender: out_tx.clone() });
|
||||
}
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
||||
*attached = Some((session, conn_id, req_id));
|
||||
Some(OpResp::AttachReady { connection_id: conn_id })
|
||||
}
|
||||
OpReq::AttachIO(frame) => {
|
||||
let Some((session, conn_id, _)) = attached.clone() else {
|
||||
return Some(OpResp::Err("not attached".into()));
|
||||
};
|
||||
let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else {
|
||||
return Some(OpResp::Err("connection gone".into()));
|
||||
};
|
||||
match frame {
|
||||
AttachIOFrame::Stdin(b) => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await;
|
||||
}
|
||||
AttachIOFrame::Resize { cols, rows } => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
||||
}
|
||||
AttachIOFrame::Kill => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
||||
}
|
||||
AttachIOFrame::Eof => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
OpReq::Detach => {
|
||||
if let Some((s, c, _)) = attached.take() {
|
||||
if let Some(handle) = state.connections.get(&(s, c)) {
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = None;
|
||||
}
|
||||
}
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysList => {
|
||||
let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let keys: Vec<String> = text
|
||||
.lines()
|
||||
.map(|l| l.trim().to_string())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect();
|
||||
Some(OpResp::Keys(keys))
|
||||
}
|
||||
OpReq::KeysAppend { keys } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
||||
for k in keys {
|
||||
let k = k.trim();
|
||||
if k.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !text.lines().any(|l| l.trim() == k) {
|
||||
if !text.is_empty() && !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
text.push_str(k);
|
||||
text.push('\n');
|
||||
}
|
||||
}
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &text).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &text).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysRemove { keys } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
||||
let targets: Vec<String> = keys.iter().map(|k| k.trim().to_string()).collect();
|
||||
let new: String = text
|
||||
.lines()
|
||||
.filter(|l| {
|
||||
let lt = l.trim();
|
||||
!targets.iter().any(|t| t == lt)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let new = if new.is_empty() { new } else { format!("{}\n", new) };
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &new).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &new).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysReplace { content } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &content).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &content).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::Watch { session } => {
|
||||
let mut rx = state.event_bus.subscribe();
|
||||
let tx = out_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(ev) = rx.recv().await {
|
||||
let pass = match &ev {
|
||||
OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s),
|
||||
OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s),
|
||||
OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s),
|
||||
OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s),
|
||||
};
|
||||
if !pass {
|
||||
continue;
|
||||
}
|
||||
if tx.send(BackendOpMsg::Event(ev)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
Some(OpResp::WatchStarted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_sessions(state: &Arc<AppState>) -> Vec<SessionView> {
|
||||
let s = state.sessions.read().await;
|
||||
s.values()
|
||||
.map(|r| SessionView {
|
||||
id: r.id.clone(),
|
||||
has_password: r.password_hash.is_some(),
|
||||
created_at: r.created_at,
|
||||
connection_count: state.connection_count(&r.id),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn disconnect_session(state: &Arc<AppState>, name: &str) {
|
||||
let mut to_kill = Vec::new();
|
||||
for kv in state.connections.iter() {
|
||||
if kv.key().0 == name {
|
||||
to_kill.push((kv.key().clone(), kv.value().clone()));
|
||||
}
|
||||
}
|
||||
for (_, handle) in &to_kill {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn reload_authorized_keys(state: &Arc<AppState>, text: &str) {
|
||||
let parsed = parse_authorized_keys(text);
|
||||
let mut k = state.authorized_keys.write().await;
|
||||
*k = parsed;
|
||||
}
|
||||
|
||||
async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> {
|
||||
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
|
||||
socket.send(Message::Text(t)).await
|
||||
}
|
||||
|
||||
async fn recv_op(socket: &mut WebSocket) -> Option<OpMsg> {
|
||||
loop {
|
||||
match socket.recv().await? {
|
||||
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
|
||||
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
|
||||
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
|
||||
Ok(Message::Close(_)) => return None,
|
||||
Err(_) => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn now_unix() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() as i64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
Reference in New Issue
Block a user