initial commit
This commit is contained in:
24
crates/rsh-backend/Cargo.toml
Normal file
24
crates/rsh-backend/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "rsh-backend"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
rsh-types = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
ssh-key = { workspace = true }
|
||||
signature = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
time = { workspace = true }
|
||||
29
crates/rsh-backend/src/auth.rs
Normal file
29
crates/rsh-backend/src/auth.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use argon2::password_hash::{PasswordHash, PasswordVerifier};
|
||||
use argon2::Argon2;
|
||||
use ssh_key::public::PublicKey;
|
||||
use ssh_key::SshSig;
|
||||
|
||||
pub fn verify_password(password: &str, hash: &str) -> bool {
|
||||
let parsed = match PasswordHash::new(hash) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return false,
|
||||
};
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
pub fn verify_signature(pubkey: &PublicKey, nonce: &[u8], signature_blob: &[u8]) -> Result<()> {
|
||||
let sig = SshSig::from_pem(signature_blob)
|
||||
.map_err(|e| anyhow!("parse signature: {e}"))?;
|
||||
pubkey
|
||||
.verify("rsh-auth", nonce, &sig)
|
||||
.map_err(|e| anyhow!("verify failed: {e}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn find_key<'a>(keys: &'a [PublicKey], offered: &PublicKey) -> Option<&'a PublicKey> {
|
||||
let fp = offered.fingerprint(Default::default());
|
||||
keys.iter().find(|k| k.fingerprint(Default::default()) == fp)
|
||||
}
|
||||
30
crates/rsh-backend/src/config.rs
Normal file
30
crates/rsh-backend/src/config.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
pub data_dir: PathBuf,
|
||||
pub bind: SocketAddr,
|
||||
pub log: String,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> anyhow::Result<Self> {
|
||||
let data_dir = std::env::var("RSH_DATA")
|
||||
.unwrap_or_else(|_| "/var/lib/rsh".to_string())
|
||||
.into();
|
||||
let bind: SocketAddr = std::env::var("RSH_BIND")
|
||||
.unwrap_or_else(|_| "0.0.0.0:7777".to_string())
|
||||
.parse()?;
|
||||
let log = std::env::var("RSH_LOG").unwrap_or_else(|_| "info,tower_http=warn".to_string());
|
||||
Ok(Self { data_dir, bind, log })
|
||||
}
|
||||
|
||||
pub fn sessions_path(&self) -> PathBuf {
|
||||
self.data_dir.join("sessions.json")
|
||||
}
|
||||
|
||||
pub fn authorized_keys_path(&self) -> PathBuf {
|
||||
self.data_dir.join("authorized_keys")
|
||||
}
|
||||
}
|
||||
10
crates/rsh-backend/src/keys.rs
Normal file
10
crates/rsh-backend/src/keys.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use ssh_key::PublicKey;
|
||||
|
||||
pub fn parse_authorized_keys(content: &str) -> Vec<PublicKey> {
|
||||
content
|
||||
.lines()
|
||||
.map(|l| l.trim())
|
||||
.filter(|l| !l.is_empty() && !l.starts_with('#'))
|
||||
.filter_map(|l| PublicKey::from_openssh(l).ok())
|
||||
.collect()
|
||||
}
|
||||
60
crates/rsh-backend/src/main.rs
Normal file
60
crates/rsh-backend/src/main.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
mod auth;
|
||||
mod config;
|
||||
mod keys;
|
||||
mod persist;
|
||||
mod state;
|
||||
mod ws_op;
|
||||
mod ws_stub;
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use config::Config;
|
||||
use state::AppState;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let cfg = Config::from_env()?;
|
||||
let filter = EnvFilter::try_new(&cfg.log).unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
tracing_subscriber::fmt().with_env_filter(filter).init();
|
||||
|
||||
tokio::fs::create_dir_all(&cfg.data_dir).await.ok();
|
||||
|
||||
let state = Arc::new(AppState::new(cfg.clone()));
|
||||
|
||||
let sessions = persist::load_sessions(&cfg.sessions_path())
|
||||
.await
|
||||
.context("load sessions")?;
|
||||
{
|
||||
let mut map = state.sessions.write().await;
|
||||
let mut h: HashMap<String, _> = HashMap::new();
|
||||
for s in sessions {
|
||||
h.insert(s.id.clone(), s);
|
||||
}
|
||||
*map = h;
|
||||
}
|
||||
|
||||
let keys_text = persist::load_authorized_keys_text(&cfg.authorized_keys_path())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
{
|
||||
let mut k = state.authorized_keys.write().await;
|
||||
*k = keys::parse_authorized_keys(&keys_text);
|
||||
tracing::info!(count = k.len(), "loaded authorized keys");
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/healthz", get(|| async { "ok" }))
|
||||
.route("/ws/stub", get(ws_stub::handler))
|
||||
.route("/ws/op", get(ws_op::handler))
|
||||
.with_state(state.clone())
|
||||
.layer(tower_http::trace::TraceLayer::new_for_http());
|
||||
|
||||
tracing::info!(bind = %cfg.bind, data = ?cfg.data_dir, "rsh-backend listening");
|
||||
let listener = tokio::net::TcpListener::bind(cfg.bind).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
Ok(())
|
||||
}
|
||||
45
crates/rsh-backend/src/persist.rs
Normal file
45
crates/rsh-backend/src/persist.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use anyhow::Context;
|
||||
use rsh_types::SessionRecord;
|
||||
use std::path::Path;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
pub async fn load_sessions(path: &Path) -> anyhow::Result<Vec<SessionRecord>> {
|
||||
if !path.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let data = fs::read(path).await.with_context(|| format!("read {:?}", path))?;
|
||||
let v: Vec<SessionRecord> = serde_json::from_slice(&data)?;
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
pub async fn save_sessions(path: &Path, sessions: &[SessionRecord]) -> anyhow::Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await.ok();
|
||||
}
|
||||
let tmp = path.with_extension("json.tmp");
|
||||
let bytes = serde_json::to_vec_pretty(sessions)?;
|
||||
let mut f = fs::File::create(&tmp).await?;
|
||||
f.write_all(&bytes).await?;
|
||||
f.sync_all().await?;
|
||||
drop(f);
|
||||
fs::rename(&tmp, path).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn load_authorized_keys_text(path: &Path) -> anyhow::Result<String> {
|
||||
if !path.exists() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
Ok(fs::read_to_string(path).await?)
|
||||
}
|
||||
|
||||
pub async fn save_authorized_keys_text(path: &Path, content: &str) -> anyhow::Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await.ok();
|
||||
}
|
||||
let tmp = path.with_extension("tmp");
|
||||
fs::write(&tmp, content.as_bytes()).await?;
|
||||
fs::rename(&tmp, path).await?;
|
||||
Ok(())
|
||||
}
|
||||
71
crates/rsh-backend/src/state.rs
Normal file
71
crates/rsh-backend/src/state.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use crate::config::Config;
|
||||
use dashmap::DashMap;
|
||||
use rsh_types::{BackendOpMsg, BackendStubMsg, OpEvent, SessionRecord, StubInfo};
|
||||
use ssh_key::PublicKey;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
|
||||
|
||||
pub struct ConnHandle {
|
||||
pub info: StubInfo,
|
||||
pub to_stub: mpsc::Sender<BackendStubMsg>,
|
||||
pub attach: Mutex<Option<AttachSink>>,
|
||||
pub connected_at: i64,
|
||||
}
|
||||
|
||||
pub struct AttachSink {
|
||||
pub req_id: u64,
|
||||
pub sender: mpsc::Sender<BackendOpMsg>,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
pub cfg: Config,
|
||||
pub sessions: RwLock<HashMap<String, SessionRecord>>,
|
||||
pub connections: DashMap<(String, u64), Arc<ConnHandle>>,
|
||||
pub next_conn_id: DashMap<String, AtomicU64>,
|
||||
pub authorized_keys: RwLock<Vec<PublicKey>>,
|
||||
pub event_bus: broadcast::Sender<OpEvent>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(cfg: Config) -> Self {
|
||||
let (tx, _) = broadcast::channel(256);
|
||||
Self {
|
||||
cfg,
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
connections: DashMap::new(),
|
||||
next_conn_id: DashMap::new(),
|
||||
authorized_keys: RwLock::new(Vec::new()),
|
||||
event_bus: tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alloc_conn_id(&self, session: &str) -> u64 {
|
||||
let entry = self
|
||||
.next_conn_id
|
||||
.entry(session.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(1));
|
||||
entry.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn connection_count(&self, session: &str) -> u32 {
|
||||
self.connections
|
||||
.iter()
|
||||
.filter(|kv| kv.key().0 == session)
|
||||
.count() as u32
|
||||
}
|
||||
|
||||
pub fn list_connections(&self, filter: Option<&str>) -> Vec<rsh_types::ConnectionView> {
|
||||
self.connections
|
||||
.iter()
|
||||
.filter(|kv| filter.map_or(true, |s| kv.key().0 == s))
|
||||
.map(|kv| rsh_types::ConnectionView {
|
||||
session_id: kv.key().0.clone(),
|
||||
connection_id: kv.key().1,
|
||||
info: kv.value().info.clone(),
|
||||
connected_at: kv.value().connected_at,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
388
crates/rsh-backend/src/ws_op.rs
Normal file
388
crates/rsh-backend/src/ws_op.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
use crate::auth::{find_key, verify_signature};
|
||||
use crate::keys::parse_authorized_keys;
|
||||
use crate::persist;
|
||||
use crate::state::{AppState, AttachSink};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::extract::{State, WebSocketUpgrade};
|
||||
use axum::response::IntoResponse;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use rand::RngCore;
|
||||
use rsh_types::{
|
||||
AttachIOFrame, BackendOpMsg, BackendStubMsg, OpEvent, OpMsg, OpReq, OpResp, SessionRecord,
|
||||
SessionView,
|
||||
};
|
||||
use ssh_key::PublicKey;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| run(socket, state))
|
||||
}
|
||||
|
||||
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
|
||||
let pubkey = match auth_handshake(&mut socket, &state).await {
|
||||
Ok(k) => k,
|
||||
Err(reason) => {
|
||||
let _ = send(&mut socket, &BackendOpMsg::AuthFail { reason }).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
if send(&mut socket, &BackendOpMsg::AuthOk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
tracing::info!(fingerprint = %pubkey.fingerprint(Default::default()), "operator authed");
|
||||
|
||||
let (out_tx, mut out_rx) = mpsc::channel::<BackendOpMsg>(128);
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
while let Some(msg) = out_rx.recv().await {
|
||||
let text = match serde_json::to_string(&msg) {
|
||||
Ok(t) => t,
|
||||
Err(_) => break,
|
||||
};
|
||||
if sink.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = sink.close().await;
|
||||
});
|
||||
|
||||
let mut attached: Option<(String, u64, u64)> = None;
|
||||
|
||||
while let Some(Ok(msg)) = stream.next().await {
|
||||
let text = match msg {
|
||||
Message::Text(t) => t,
|
||||
Message::Binary(b) => match String::from_utf8(b) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
},
|
||||
Message::Close(_) => break,
|
||||
_ => continue,
|
||||
};
|
||||
let op: OpMsg = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("bad op msg: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let (req_id, body) = match op {
|
||||
OpMsg::Req { id, body } => (id, body),
|
||||
_ => continue,
|
||||
};
|
||||
let resp = handle_req(&state, &out_tx, &mut attached, req_id, body).await;
|
||||
if let Some(r) = resp {
|
||||
if out_tx.send(BackendOpMsg::Resp { id: req_id, body: r }).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((s, c, _)) = attached {
|
||||
if let Some(handle) = state.connections.get(&(s, c)) {
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = None;
|
||||
}
|
||||
}
|
||||
drop(out_tx);
|
||||
let _ = writer.await;
|
||||
}
|
||||
|
||||
async fn auth_handshake(socket: &mut WebSocket, state: &Arc<AppState>) -> Result<PublicKey, String> {
|
||||
let init = recv_op(socket).await.ok_or_else(|| "no message".to_string())?;
|
||||
let offered_str = match init {
|
||||
OpMsg::AuthInit { pubkey_openssh } => pubkey_openssh,
|
||||
_ => return Err("expected AuthInit".into()),
|
||||
};
|
||||
let offered = PublicKey::from_openssh(&offered_str).map_err(|e| format!("bad pubkey: {e}"))?;
|
||||
let keys = state.authorized_keys.read().await;
|
||||
let matched = find_key(&keys, &offered).cloned();
|
||||
drop(keys);
|
||||
let matched = matched.ok_or_else(|| "key not authorized".to_string())?;
|
||||
|
||||
let mut nonce = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut nonce);
|
||||
send(socket, &BackendOpMsg::Challenge { nonce }).await.map_err(|e| e.to_string())?;
|
||||
let signed = recv_op(socket).await.ok_or_else(|| "no signature".to_string())?;
|
||||
let (sig_bytes, _alg) = match signed {
|
||||
OpMsg::AuthSign { signature, alg } => (signature, alg),
|
||||
_ => return Err("expected AuthSign".into()),
|
||||
};
|
||||
verify_signature(&matched, &nonce, &sig_bytes).map_err(|e| e.to_string())?;
|
||||
Ok(matched)
|
||||
}
|
||||
|
||||
async fn handle_req(
|
||||
state: &Arc<AppState>,
|
||||
out_tx: &mpsc::Sender<BackendOpMsg>,
|
||||
attached: &mut Option<(String, u64, u64)>,
|
||||
req_id: u64,
|
||||
body: OpReq,
|
||||
) -> Option<OpResp> {
|
||||
match body {
|
||||
OpReq::SessionList => Some(OpResp::Sessions(list_sessions(state).await)),
|
||||
OpReq::SessionCreate { name, password_hash } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
if s.contains_key(&name) {
|
||||
return Some(OpResp::Err(format!("session '{name}' already exists")));
|
||||
}
|
||||
let rec = SessionRecord {
|
||||
id: name.clone(),
|
||||
password_hash,
|
||||
created_at: now_unix(),
|
||||
};
|
||||
s.insert(name.clone(), rec.clone());
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
let view = SessionView {
|
||||
id: rec.id.clone(),
|
||||
has_password: rec.password_hash.is_some(),
|
||||
created_at: rec.created_at,
|
||||
connection_count: 0,
|
||||
};
|
||||
let _ = state.event_bus.send(OpEvent::NewSession(view));
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::SessionDelete { name, disconnect } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
if s.remove(&name).is_none() {
|
||||
return Some(OpResp::Err(format!("no such session '{name}'")));
|
||||
}
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
if disconnect {
|
||||
disconnect_session(state, &name).await;
|
||||
}
|
||||
let _ = state.event_bus.send(OpEvent::SessionDeleted { session: name });
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::SessionUpdate { name, set_password_hash, disconnect } => {
|
||||
let mut s = state.sessions.write().await;
|
||||
let Some(rec) = s.get_mut(&name) else {
|
||||
return Some(OpResp::Err(format!("no such session '{name}'")));
|
||||
};
|
||||
if let Some(pw) = set_password_hash {
|
||||
rec.password_hash = pw;
|
||||
}
|
||||
let snapshot: Vec<_> = s.values().cloned().collect();
|
||||
drop(s);
|
||||
if let Err(e) = persist::save_sessions(&state.cfg.sessions_path(), &snapshot).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
if disconnect {
|
||||
disconnect_session(state, &name).await;
|
||||
}
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::ConnectionList { session } => {
|
||||
Some(OpResp::Connections(state.list_connections(session.as_deref())))
|
||||
}
|
||||
OpReq::Attach { session, connection_id, pty: _, cols, rows } => {
|
||||
let conn_id = match connection_id {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
let mut found = None;
|
||||
for kv in state.connections.iter() {
|
||||
if kv.key().0 == session {
|
||||
if found.is_some() {
|
||||
return Some(OpResp::Err("multiple connections; specify id".into()));
|
||||
}
|
||||
found = Some(kv.key().1);
|
||||
}
|
||||
}
|
||||
match found {
|
||||
Some(c) => c,
|
||||
None => return Some(OpResp::Err("no connections".into())),
|
||||
}
|
||||
}
|
||||
};
|
||||
let Some(handle) = state.connections.get(&(session.clone(), conn_id)).map(|h| h.clone()) else {
|
||||
return Some(OpResp::Err("connection not found".into()));
|
||||
};
|
||||
{
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = Some(AttachSink { req_id, sender: out_tx.clone() });
|
||||
}
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
||||
*attached = Some((session, conn_id, req_id));
|
||||
Some(OpResp::AttachReady { connection_id: conn_id })
|
||||
}
|
||||
OpReq::AttachIO(frame) => {
|
||||
let Some((session, conn_id, _)) = attached.clone() else {
|
||||
return Some(OpResp::Err("not attached".into()));
|
||||
};
|
||||
let Some(handle) = state.connections.get(&(session, conn_id)).map(|h| h.clone()) else {
|
||||
return Some(OpResp::Err("connection gone".into()));
|
||||
};
|
||||
match frame {
|
||||
AttachIOFrame::Stdin(b) => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Stdin(b)).await;
|
||||
}
|
||||
AttachIOFrame::Resize { cols, rows } => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Resize { cols, rows }).await;
|
||||
}
|
||||
AttachIOFrame::Kill => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
||||
}
|
||||
AttachIOFrame::Eof => {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Stdin(Vec::new())).await;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
OpReq::Detach => {
|
||||
if let Some((s, c, _)) = attached.take() {
|
||||
if let Some(handle) = state.connections.get(&(s, c)) {
|
||||
let mut a = handle.attach.lock().await;
|
||||
*a = None;
|
||||
}
|
||||
}
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysList => {
|
||||
let text = persist::load_authorized_keys_text(&state.cfg.authorized_keys_path())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let keys: Vec<String> = text
|
||||
.lines()
|
||||
.map(|l| l.trim().to_string())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect();
|
||||
Some(OpResp::Keys(keys))
|
||||
}
|
||||
OpReq::KeysAppend { keys } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
let mut text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
||||
for k in keys {
|
||||
let k = k.trim();
|
||||
if k.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !text.lines().any(|l| l.trim() == k) {
|
||||
if !text.is_empty() && !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
text.push_str(k);
|
||||
text.push('\n');
|
||||
}
|
||||
}
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &text).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &text).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysRemove { keys } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
let text = persist::load_authorized_keys_text(&path).await.unwrap_or_default();
|
||||
let targets: Vec<String> = keys.iter().map(|k| k.trim().to_string()).collect();
|
||||
let new: String = text
|
||||
.lines()
|
||||
.filter(|l| {
|
||||
let lt = l.trim();
|
||||
!targets.iter().any(|t| t == lt)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let new = if new.is_empty() { new } else { format!("{}\n", new) };
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &new).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &new).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::KeysReplace { content } => {
|
||||
let path = state.cfg.authorized_keys_path();
|
||||
if let Err(e) = persist::save_authorized_keys_text(&path, &content).await {
|
||||
return Some(OpResp::Err(format!("persist: {e}")));
|
||||
}
|
||||
reload_authorized_keys(state, &content).await;
|
||||
Some(OpResp::Ok)
|
||||
}
|
||||
OpReq::Watch { session } => {
|
||||
let mut rx = state.event_bus.subscribe();
|
||||
let tx = out_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(ev) = rx.recv().await {
|
||||
let pass = match &ev {
|
||||
OpEvent::NewConnection(v) => session.as_ref().map_or(true, |s| &v.session_id == s),
|
||||
OpEvent::ConnectionClosed { session: s, .. } => session.as_ref().map_or(true, |x| x == s),
|
||||
OpEvent::NewSession(v) => session.as_ref().map_or(true, |s| &v.id == s),
|
||||
OpEvent::SessionDeleted { session: s } => session.as_ref().map_or(true, |x| x == s),
|
||||
};
|
||||
if !pass {
|
||||
continue;
|
||||
}
|
||||
if tx.send(BackendOpMsg::Event(ev)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
Some(OpResp::WatchStarted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_sessions(state: &Arc<AppState>) -> Vec<SessionView> {
|
||||
let s = state.sessions.read().await;
|
||||
s.values()
|
||||
.map(|r| SessionView {
|
||||
id: r.id.clone(),
|
||||
has_password: r.password_hash.is_some(),
|
||||
created_at: r.created_at,
|
||||
connection_count: state.connection_count(&r.id),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn disconnect_session(state: &Arc<AppState>, name: &str) {
|
||||
let mut to_kill = Vec::new();
|
||||
for kv in state.connections.iter() {
|
||||
if kv.key().0 == name {
|
||||
to_kill.push((kv.key().clone(), kv.value().clone()));
|
||||
}
|
||||
}
|
||||
for (_, handle) in &to_kill {
|
||||
let _ = handle.to_stub.send(BackendStubMsg::Kill).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn reload_authorized_keys(state: &Arc<AppState>, text: &str) {
|
||||
let parsed = parse_authorized_keys(text);
|
||||
let mut k = state.authorized_keys.write().await;
|
||||
*k = parsed;
|
||||
}
|
||||
|
||||
async fn send(socket: &mut WebSocket, msg: &BackendOpMsg) -> Result<(), axum::Error> {
|
||||
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
|
||||
socket.send(Message::Text(t)).await
|
||||
}
|
||||
|
||||
async fn recv_op(socket: &mut WebSocket) -> Option<OpMsg> {
|
||||
loop {
|
||||
match socket.recv().await? {
|
||||
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
|
||||
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
|
||||
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
|
||||
Ok(Message::Close(_)) => return None,
|
||||
Err(_) => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn now_unix() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() as i64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
155
crates/rsh-backend/src/ws_stub.rs
Normal file
155
crates/rsh-backend/src/ws_stub.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use crate::auth::verify_password;
|
||||
use crate::state::{AppState, ConnHandle};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::extract::{State, WebSocketUpgrade};
|
||||
use axum::response::IntoResponse;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use rsh_types::{BackendOpMsg, BackendStubMsg, ConnectionView, OpEvent, OpResp, StubMsg};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
pub async fn handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| run(socket, state))
|
||||
}
|
||||
|
||||
async fn run(mut socket: WebSocket, state: Arc<AppState>) {
|
||||
let hello = match recv_text::<StubMsg>(&mut socket).await {
|
||||
Some(StubMsg::Hello { session_id, password, info }) => (session_id, password, info),
|
||||
_ => {
|
||||
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "expected Hello".into() }).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let (session_id, password, info) = hello;
|
||||
let sessions = state.sessions.read().await;
|
||||
let session = match sessions.get(&session_id) {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
drop(sessions);
|
||||
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "no such session".into() }).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
drop(sessions);
|
||||
if let Some(hash) = &session.password_hash {
|
||||
let provided = password.unwrap_or_default();
|
||||
if !verify_password(&provided, hash) {
|
||||
let _ = send(&mut socket, &BackendStubMsg::Rejected { reason: "bad password".into() }).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
let conn_id = state.alloc_conn_id(&session_id);
|
||||
let (to_stub_tx, mut to_stub_rx) = mpsc::channel::<BackendStubMsg>(64);
|
||||
let connected_at = now_unix();
|
||||
let handle = Arc::new(ConnHandle {
|
||||
info: info.clone(),
|
||||
to_stub: to_stub_tx.clone(),
|
||||
attach: Mutex::new(None),
|
||||
connected_at,
|
||||
});
|
||||
state.connections.insert((session_id.clone(), conn_id), handle.clone());
|
||||
let _ = state.event_bus.send(OpEvent::NewConnection(ConnectionView {
|
||||
session_id: session_id.clone(),
|
||||
connection_id: conn_id,
|
||||
info: info.clone(),
|
||||
connected_at,
|
||||
}));
|
||||
if send(&mut socket, &BackendStubMsg::Accepted { connection_id: conn_id }).await.is_err() {
|
||||
cleanup(&state, &session_id, conn_id).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let (mut ws_sink, mut ws_stream) = socket.split();
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
while let Some(msg) = to_stub_rx.recv().await {
|
||||
let text = match serde_json::to_string(&msg) {
|
||||
Ok(t) => t,
|
||||
Err(_) => break,
|
||||
};
|
||||
if ws_sink.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = ws_sink.close().await;
|
||||
});
|
||||
|
||||
let state_r = state.clone();
|
||||
let session_r = session_id.clone();
|
||||
let handle_r = handle.clone();
|
||||
let reader = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_stream.next().await {
|
||||
let text = match msg {
|
||||
Message::Text(t) => t,
|
||||
Message::Binary(b) => match String::from_utf8(b) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
},
|
||||
Message::Close(_) => break,
|
||||
_ => continue,
|
||||
};
|
||||
let parsed: StubMsg = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
match parsed {
|
||||
StubMsg::Stdout(b) => forward_op(&handle_r, OpResp::Stdout(b)).await,
|
||||
StubMsg::Stderr(b) => forward_op(&handle_r, OpResp::Stderr(b)).await,
|
||||
StubMsg::Exited { code } => {
|
||||
forward_op(&handle_r, OpResp::Exited { code }).await;
|
||||
break;
|
||||
}
|
||||
StubMsg::Pong => {}
|
||||
StubMsg::Hello { .. } => {}
|
||||
}
|
||||
}
|
||||
cleanup(&state_r, &session_r, conn_id).await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(writer, reader);
|
||||
}
|
||||
|
||||
async fn forward_op(handle: &Arc<crate::state::ConnHandle>, resp: OpResp) {
|
||||
let attach = handle.attach.lock().await;
|
||||
if let Some(sink) = attach.as_ref() {
|
||||
let _ = sink
|
||||
.sender
|
||||
.send(BackendOpMsg::Resp { id: sink.req_id, body: resp })
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup(state: &Arc<AppState>, session_id: &str, conn_id: u64) {
|
||||
state.connections.remove(&(session_id.to_string(), conn_id));
|
||||
let _ = state.event_bus.send(OpEvent::ConnectionClosed {
|
||||
session: session_id.to_string(),
|
||||
connection_id: conn_id,
|
||||
});
|
||||
}
|
||||
|
||||
async fn send(socket: &mut WebSocket, msg: &BackendStubMsg) -> Result<(), axum::Error> {
|
||||
let t = serde_json::to_string(msg).map_err(|e| axum::Error::new(e))?;
|
||||
socket.send(Message::Text(t)).await
|
||||
}
|
||||
|
||||
async fn recv_text<T: serde::de::DeserializeOwned>(socket: &mut WebSocket) -> Option<T> {
|
||||
loop {
|
||||
match socket.recv().await? {
|
||||
Ok(Message::Text(t)) => return serde_json::from_str(&t).ok(),
|
||||
Ok(Message::Binary(b)) => return serde_json::from_slice(&b).ok(),
|
||||
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
|
||||
Ok(Message::Close(_)) => return None,
|
||||
Err(_) => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn now_unix() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() as i64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
Reference in New Issue
Block a user