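//! Operator WebSocket endpoint: SSH public-key challenge/response
//! authentication, followed by a JSON request/response loop over sessions,
//! connections, attach I/O, authorized keys, and event watches.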
use crate::auth::{find_key, verify_signature};
|
|
use crate::keys::parse_authorized_keys;
|
|
use crate::persist;
|
|
use crate::state::{AppState, AttachSink};
|
|
use axum::extract::ws::{Message, WebSocket};
|
|
use axum::extract::{State, WebSocketUpgrade};
|
|
use axum::response::IntoResponse;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use rand::RngCore;
|
|
use rsh_types::{
|
|
AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord,
|
|
SessionView,
|
|
};
|
|
use ssh_key::PublicKey;
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
|
|
pub async fn handler(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<Arc<AppState>>,
|
|
) -> impl IntoResponse {
|
|
ws.on_upgrade(move |socket| run(socket, state))
|
|
}
|
|
|
|
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
|
|
let pubkey = match auth_handshake(&mut socket, &state).await {
|
|
Ok(k) => k,
|
|
Err(reason) => {
|
|
let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await;
|
|
return;
|
|
}
|
|
};
|
|
if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() {
|
|
return;
|
|
}
|
|
tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed");
|
|
|
|
let (out_tx, mut out_rx) = mpsc::channel::<BackendOpMsg>(128);
|
|
let (mut sink, mut stream) = socket.split();
|
|
|
|
let writer = tokio::spawn(async move {
|
|
while let Some(msg) = out_rx.recv().await {
|
|
let text = match serde_json::to_string(&msg) {
|
|
Ok(t) => t,
|
|
Err(_) => break,
|
|
};
|
|
if sink.send(Message::Text(text)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
let _ = sink.close().await;
|
|
});
|
|
|
|
let mut attached: Option<(String, u64, u64)> = None;
|
|
|
|
while let Some(Ok(msg)) = stream.next().await {
|
|
let text = match msg {
|
|
Message::Text(t) => t,
|
|
Message::Binary(b) => match String::from_utf8(b) {
|
|
Ok(s) => s,
|
|
Err(_) => continue,
|
|
},
|
|
Message::Close(_) => break,
|
|
_ => continue,
|
|
};
|
|
let op: OpMsg = match serde_json::from_str(&text) {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
tracing::warn!("bad op msg: {e}");
|
|
continue;
|
|
}
|
|
};
|
|
let (req_id, body) = match op {
|
|
OpMsg::Req { id, body } => (id, body),
|
|
_ => continue,
|
|
};
|
|
let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await;
|
|
if let Some(r) = resp {
|
|
if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some((s, c, _)) = attached {
|
|
if let Some(handle) = state.connections.get(&(s, c)) {
|
|
let mut a = handle.attach.lock().await;
|
|
*a = None;
|
|
}
|
|
}
|
|
drop(out_tx);
|
|
let _ = writer.await;
|
|
}
|
|
|
|
async fn auth_handshake(socket: &mut WebSocket, state: &Arc<AppState>) -> Result<PublicKey, String> {
|
|
let init = recv_op(socket).await.ok_or_else(|| "no message".to_string())?;
|
|
let offered_str = match init {
|
|
OpMsg::AuthInit { pubkey_openssh } => pubkey_openssh,
|
|
_ => return Err("expected AuthInit".into()),
|
|
};
|
|
let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?;
|
|
let keys = state.authorized_keys.read().await;
|
|
let matched = find_key(&keys, &offered).cloned();
|
|
drop(keys);
|
|
let matched = matched.ok_or_else(|| "key not authorized".to_string())?;
|
|
|
|
let mut nonce = [0u8; 32];
|
|
rand::thread_rng().fill_bytes(&mut nonce);
|
|
send(socket, &BackendOpMsg::Challenge { nonce }).await.map_err(|e| e.to_string())?;
|
|
let signed = recv_op(socket).await.ok_or_else(|| "no signature".to_string())?;
|
|
let (sig_bytes, _alg) = match signed {
|
|
OpMsg::AuthSign { signature, alg } => (signature, alg),
|
|
_ => return Err("expected AuthSign".into()),
|
|
};
|
|
verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?;
|
|
Ok(matched)
|
|
}
|
|
|
|
async fn handle_req(
|
|
state: &Arc<AppState>,
|
|
out_tx: &mpsc::Sender<BackendOpMsg>,
|
|
attached: &mut Option<(String, u64, u64)>,
|
|
req_id: u64,
|
|
body: OpReq,
|
|
) -> Option<OpResp> {
|
|
match body {
|
|
OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)),
|
|
OpReq::SessionCreate { name, password_hash } => {
|
|
let mut s = state.sessions.write().await;
|
|
if s.contains_key(&name) {
|
|
return Some(OpResp::Err(format!("session '{name}' already exists")));
|
|
}
|
|
let rec = SessionRecord {
|
|
id: name.clone(),
|
|
password_hash,
|
|
created_at: now_unix(),
|
|
};
|
|
s.insert(name.clone(), rec.clone());
|
|
let snapshot: Vec<_> = s.values().cloned().collect();
|
|
drop(s);
|
|
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
let view = SessionView {
|
|
id: rec.id.clone(),
|
|
has_password: rec.password_hash.is_some(),
|
|
created_at: rec.created_at,
|
|
connection_count: 0,
|
|
};
|
|
let _ = state.event_bus.send(OpEvent::NewSession(view));
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::SessionDelete { name, disconnect } => {
|
|
let mut s = state.sessions.write().await;
|
|
if s.remove(&name).is_none() {
|
|
return Some(OpResp::Err(format!("no such session '{name}'")));
|
|
}
|
|
let snapshot: Vec<_> = s.values().cloned().collect();
|
|
drop(s);
|
|
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
if disconnect {
|
|
disconnect_session(state, &name).await;
|
|
}
|
|
let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name });
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::SessionUpdate { name, set_password_hash, disconnect } => {
|
|
let mut s = state.sessions.write().await;
|
|
let Some(rec) = s.get_mut(&name) else {
|
|
return Some(OpResp::Err(format!("no such session '{name}'")));
|
|
};
|
|
if let Some(pw) = set_password_hash {
|
|
rec.password_hash = pw;
|
|
}
|
|
let snapshot: Vec<_> = s.values().cloned().collect();
|
|
drop(s);
|
|
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
if disconnect {
|
|
disconnect_session(state, &name).await;
|
|
}
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::ConnectionList { session } => {
|
|
Some(OpResp::Connections(state.list_connections(session.as_deref())))
|
|
}
|
|
OpReq::Attach { session, connection_id, pty: _, cols, rows } => {
|
|
let conn_id = match connection_id {
|
|
Some(c) => c,
|
|
None => {
|
|
let mut found = None;
|
|
for kv in state.connections.iter() {
|
|
if kv.key().0 == session {
|
|
if found.is_some() {
|
|
return Some(OpResp::Err("multiple connections; specify id".into()));
|
|
}
|
|
found = Some(kv.key().1);
|
|
}
|
|
}
|
|
match found {
|
|
Some(c) => c,
|
|
None => return Some(OpResp::Err("no connections".into())),
|
|
}
|
|
}
|
|
};
|
|
let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else {
|
|
return Some(OpResp::Err("connection not found".into()));
|
|
};
|
|
{
|
|
let mut a = handle.attach.lock().await;
|
|
*a = Some(AttachSink { req_id, sender: out_tx.clone() });
|
|
}
|
|
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
|
*attached = Some((session, conn_id, req_id));
|
|
Some(OpResp::AttachReady { connection_id: conn_id })
|
|
}
|
|
OpReq::AttachIO(frame) => {
|
|
let Some((session, conn_id, _)) = attached.clone() else {
|
|
return Some(OpResp::Err("not attached".into()));
|
|
};
|
|
let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else {
|
|
return Some(OpResp::Err("connection gone".into()));
|
|
};
|
|
match frame {
|
|
AttachIOFrame::Stdin(b) => {
|
|
let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await;
|
|
}
|
|
AttachIOFrame::Resize { cols, rows } => {
|
|
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
|
}
|
|
AttachIOFrame::Kill => {
|
|
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
|
}
|
|
AttachIOFrame::Eof => {
|
|
let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await;
|
|
}
|
|
}
|
|
None
|
|
}
|
|
OpReq::Detach => {
|
|
if let Some((s, c, _)) = attached.take() {
|
|
if let Some(handle) = state.connections.get(&(s, c)) {
|
|
let mut a = handle.attach.lock().await;
|
|
*a = None;
|
|
}
|
|
}
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::KeysList => {
|
|
let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path())
|
|
.await
|
|
.unwrap_or_default();
|
|
let keys: Vec<String> = text
|
|
.lines()
|
|
.map(|l| l.trim().to_string())
|
|
.filter(|l| !l.is_empty())
|
|
.collect();
|
|
Some(OpResp::Keys(keys))
|
|
}
|
|
OpReq::KeysAppend { keys } => {
|
|
let path = state.cfg.authorized_keys_path();
|
|
let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
|
for k in keys {
|
|
let k = k.trim();
|
|
if k.is_empty() {
|
|
continue;
|
|
}
|
|
if !text.lines().any(|l| l.trim() == k) {
|
|
if !text.is_empty() && !text.ends_with('\n') {
|
|
text.push('\n');
|
|
}
|
|
text.push_str(k);
|
|
text.push('\n');
|
|
}
|
|
}
|
|
if let Err(e) = persist::save_authorized_keys_text(&path, &text).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
reload_authorized_keys(state, &text).await;
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::KeysRemove { keys } => {
|
|
let path = state.cfg.authorized_keys_path();
|
|
let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
|
let targets: Vec<String> = keys.iter().map(|k| k.trim().to_string()).collect();
|
|
let new: String = text
|
|
.lines()
|
|
.filter(|l| {
|
|
let lt = l.trim();
|
|
!targets.iter().any(|t| t == lt)
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n");
|
|
let new = if new.is_empty() { new } else { format!("{}\n", new) };
|
|
if let Err(e) = persist::save_authorized_keys_text(&path, &new).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
reload_authorized_keys(state, &new).await;
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::KeysReplace { content } => {
|
|
let path = state.cfg.authorized_keys_path();
|
|
if let Err(e) = persist::save_authorized_keys_text(&path, &content).await {
|
|
return Some(OpResp::Err(format!("persist: {e}")));
|
|
}
|
|
reload_authorized_keys(state, &content).await;
|
|
Some(OpResp::Ok)
|
|
}
|
|
OpReq::Watch { session } => {
|
|
let mut rx = state.event_bus.subscribe();
|
|
let tx = out_tx.clone();
|
|
tokio::spawn(async move {
|
|
while let Ok(ev) = rx.recv().await {
|
|
let pass = match &ev {
|
|
OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s),
|
|
OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s),
|
|
OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s),
|
|
OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s),
|
|
};
|
|
if !pass {
|
|
continue;
|
|
}
|
|
if tx.send(BackendOpMsg::Event(ev)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
Some(OpResp::WatchStarted)
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn list_sessions(state: &Arc<AppState>) -> Vec<SessionView> {
|
|
let s = state.sessions.read().await;
|
|
s.values()
|
|
.map(|r| SessionView {
|
|
id: r.id.clone(),
|
|
has_password: r.password_hash.is_some(),
|
|
created_at: r.created_at,
|
|
connection_count: state.connection_count(&r.id),
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
async fn disconnect_session(state: &Arc<AppState>, name: &str) {
|
|
let mut to_kill = Vec::new();
|
|
for kv in state.connections.iter() {
|
|
if kv.key().0 == name {
|
|
to_kill.push((kv.key().clone(), kv.value().clone()));
|
|
}
|
|
}
|
|
for (_, handle) in &to_kill {
|
|
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
|
}
|
|
}
|
|
|
|
async fn reload_authorized_keys(state: &Arc<AppState>, text: &str) {
|
|
let parsed = parse_authorized_keys(text);
|
|
let mut k = state.authorized_keys.write().await;
|
|
*k = parsed;
|
|
}
|
|
|
|
async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> {
|
|
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
|
|
socket.send(Message::Text(t)).await
|
|
}
|
|
|
|
async fn recv_op(socket: &mut WebSocket) -> Option<OpMsg> {
|
|
loop {
|
|
match socket.recv().await? {
|
|
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
|
|
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
|
|
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
|
|
Ok(Message::Close(_)) => return None,
|
|
Err(_) => return None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Current wall-clock time in whole seconds since the Unix epoch, or `0` if the
/// system clock reports a time before the epoch.
fn now_unix() -> i64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
|