Files
rsh/crates/rsh-backend/src/state.rs
2026-05-12 23:19:12 +09:00

84 lines
2.6 KiB
Rust

use crate::config::Config;
use crate::keys::parse_authorized_keys;
use dashmap::DashMap;
use rsh_types::{BackendOpMsg, BackendStubMsg, OpEvent, SessionRecord, StubInfo};
use ssh_key::PublicKey;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex, RwLock};
/// Per-connection handle stored in `AppState::connections`.
///
/// Shared behind an `Arc`, so every mutable piece uses interior mutability
/// (`Mutex`, `DashMap`, atomics).
pub struct ConnHandle {
    /// Stub information captured for this connection.
    pub info: StubInfo,
    /// Outgoing channel carrying `BackendStubMsg`s toward the stub.
    pub to_stub: mpsc::Sender<BackendStubMsg>,
    /// Currently attached sink for the primary shell; `None` when nothing is attached.
    pub attach: Mutex<Option<AttachSink>>,
    /// Time the connection was established.
    /// NOTE(review): unit/epoch not visible here — presumably Unix seconds; confirm.
    pub connected_at: i64,
    /// Additional shells, each with its own optional attached sink.
    /// Presumably keyed by ids handed out from `next_shell_id` — verify against callers.
    pub extra_shells: DashMap<u64, Mutex<Option<AttachSink>>>,
    /// Counter for allocating ids for additional shells.
    pub next_shell_id: AtomicU64,
}
/// Destination for output from an attached shell: the request that attached it
/// plus the channel used to deliver `BackendOpMsg`s.
pub struct AttachSink {
    /// Id of the request that created this attachment.
    pub req_id: u64,
    /// Channel delivering `BackendOpMsg`s to the attached party.
    pub sender: mpsc::Sender<BackendOpMsg>,
}
/// Process-wide backend state shared between request handlers.
pub struct AppState {
    /// Configuration this state was built from.
    pub cfg: Config,
    /// Session records keyed by session id.
    pub sessions: RwLock<HashMap<String, SessionRecord>>,
    /// Live connections keyed by `(session_id, connection_id)`.
    pub connections: DashMap<(String, u64), Arc<ConnHandle>>,
    /// Per-session counter backing `AppState::alloc_conn_id`.
    pub next_conn_id: DashMap<String, AtomicU64>,
    /// Authorized public keys; empty when `AppState::new` returns.
    pub authorized_keys: RwLock<Vec<PublicKey>>,
    /// Keys parsed once from `cfg.authorized_keys_env` at construction time.
    pub env_keys: Vec<PublicKey>,
    /// Broadcast sender for `OpEvent`s.
    pub event_bus: broadcast::Sender<OpEvent>,
    /// Outstanding spawn-shell requests awaiting completion.
    /// Presumably keyed by `(session_id, connection_id, shell_id)` — verify against callers.
    pub spawn_shell_pending: DashMap<(String, u64, u64), oneshot::Sender<()>>,
}
impl AppState {
    /// Capacity of the `event_bus` broadcast channel. A subscriber that falls more
    /// than this many events behind observes a lag error on its next receive.
    const EVENT_BUS_CAPACITY: usize = 256;

    /// Builds the shared backend state from `cfg`.
    ///
    /// Keys from `cfg.authorized_keys_env` (when present) are parsed once into
    /// `env_keys`; `authorized_keys` starts empty. All maps start empty.
    pub fn new(cfg: Config) -> Self {
        // The initial receiver is dropped; listeners must call `event_bus.subscribe()`.
        let (tx, _) = broadcast::channel(Self::EVENT_BUS_CAPACITY);
        let env_keys = cfg
            .authorized_keys_env
            .as_deref()
            .map(parse_authorized_keys)
            .unwrap_or_default();
        Self {
            cfg,
            sessions: RwLock::new(HashMap::new()),
            connections: DashMap::new(),
            next_conn_id: DashMap::new(),
            authorized_keys: RwLock::new(Vec::new()),
            env_keys,
            event_bus: tx,
            spawn_shell_pending: DashMap::new(),
        }
    }

    /// Returns the next connection id for `session`.
    ///
    /// Ids start at 1 for a session not yet seen and increase by one per call.
    /// This method never resets or removes a session's counter.
    pub fn alloc_conn_id(&self, session: &str) -> u64 {
        // Fast path: an existing counter only needs a shard read lock and no
        // key allocation. The `Ref` guard is dropped at the end of this `if let`,
        // before `entry` below takes the write lock on the same shard.
        if let Some(counter) = self.next_conn_id.get(session) {
            return counter.fetch_add(1, Ordering::Relaxed);
        }
        // Slow path: first id for this session. If two callers race here,
        // `or_insert_with` inserts once and both increment the same counter.
        // Relaxed is enough: `fetch_add` alone guarantees each caller a distinct value.
        self.next_conn_id
            .entry(session.to_owned())
            .or_insert_with(|| AtomicU64::new(1))
            .fetch_add(1, Ordering::Relaxed)
    }

    /// Number of entries in `connections` belonging to `session`.
    ///
    /// Scans every connection (linear in the total count). The result saturates
    /// at `u32::MAX` instead of wrapping.
    pub fn connection_count(&self, session: &str) -> u32 {
        let n = self
            .connections
            .iter()
            .filter(|kv| kv.key().0 == session)
            .count();
        // Clamp before narrowing so an oversized count cannot silently wrap.
        n.min(u32::MAX as usize) as u32
    }

    /// Snapshots connections as `ConnectionView`s.
    ///
    /// With `Some(session)`, only that session's connections are returned; with
    /// `None`, all of them. Order follows `DashMap` iteration and is unspecified.
    pub fn list_connections(&self, filter: Option<&str>) -> Vec<rsh_types::ConnectionView> {
        self.connections
            .iter()
            .filter(|kv| filter.map_or(true, |s| kv.key().0 == s))
            .map(|kv| rsh_types::ConnectionView {
                session_id: kv.key().0.clone(),
                connection_id: kv.key().1,
                info: kv.value().info.clone(),
                connected_at: kv.value().connected_at,
            })
            .collect()
    }
}