use std::collections::HashMap; use std::sync::Arc; use sqlx::PgPool; use tokio::sync::{broadcast, mpsc, Mutex, RwLock}; use tokio::task::JoinHandle; use crate::db; use crate::irc::client::{ConnectOpts, IrcConnection}; use crate::irc::event::IrcEvent; use crate::irc::parser::IrcMessage; use crate::lua::api::{HttpApi, HttpRateLimiter, IrcApi, KvApi, LogApi, LogEntry, TimerApi, TimerEvent}; use crate::lua::dispatch::dispatch_event; use crate::lua::sandbox::{create_sandbox, reset_instruction_counter}; struct BotHandle { stop_tx: mpsc::Sender<()>, msg_tx: mpsc::Sender<(String, String)>, task: JoinHandle<()>, } pub struct BotManager { bots: RwLock>, log_channels: RwLock>>, db: PgPool, http_client: reqwest::Client, } impl BotManager { pub fn new(db: PgPool) -> Self { Self { bots: RwLock::new(HashMap::new()), log_channels: RwLock::new(HashMap::new()), db, http_client: reqwest::Client::new(), } } pub async fn subscribe_logs(&self, bot_id: &str) -> broadcast::Receiver { let mut channels = self.log_channels.write().await; let tx = channels .entry(bot_id.to_string()) .or_insert_with(|| broadcast::channel(256).0); tx.subscribe() } fn get_or_create_log_tx( channels: &mut HashMap>, bot_id: &str, ) -> broadcast::Sender { channels .entry(bot_id.to_string()) .or_insert_with(|| broadcast::channel(256).0) .clone() } pub async fn is_running(&self, bot_id: &str) -> bool { self.bots.read().await.contains_key(bot_id) } pub async fn start_bot(self: &Arc, bot_id: &str) -> Result<(), String> { { let bots = self.bots.read().await; if bots.contains_key(bot_id) { return Err("bot is already running".to_string()); } } let bot = db::get_bot(&self.db, bot_id) .await .map_err(|e| format!("db error: {e}"))? .ok_or_else(|| "bot not found".to_string())?; let scripts = db::list_enabled_scripts(&self.db, bot_id) .await .map_err(|e| format!("db error: {e}"))?; let log_tx = { let mut channels = self.log_channels.write().await; Self::get_or_create_log_tx(&mut channels, bot_id) }; let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1); let (msg_tx, mut msg_rx) = mpsc::channel::<(String, String)>(64); let db = self.db.clone(); let http_client = self.http_client.clone(); let bot_id_owned = bot_id.to_string(); db::set_bot_status(&self.db, bot_id, "connecting", None) .await .ok(); let task = tokio::spawn(async move { let bot_id = bot_id_owned; let opts = ConnectOpts::from_addr( &bot.network_addr, &bot.nick, &bot.channels, bot.sasl_username.as_deref(), bot.sasl_password.as_deref(), ); let mut backoff = 1u64; let mut attempt = 0u32; loop { if attempt > 0 { let _ = db::set_bot_status(&db, &bot_id, "reconnecting", Some(&format!("attempt {attempt}, waiting {backoff}s"))).await; let _ = crate::db::insert_log(&db, &bot_id, "info", &format!("reconnecting in {backoff}s (attempt {attempt})")).await; tokio::select! { _ = tokio::time::sleep(std::time::Duration::from_secs(backoff)) => {} _ = stop_rx.recv() => { let _ = db::set_bot_status(&db, &bot_id, "stopped", None).await; return; } } backoff = (backoff * 2).min(60); } attempt += 1; let mut conn = match IrcConnection::connect(&opts).await { Ok(c) => { backoff = 1; c } Err(e) => { tracing::error!("bot {bot_id} connect failed: {e}"); let _ = db::set_bot_status(&db, &bot_id, "error", Some(&e.to_string())).await; continue; } }; if let Err(e) = conn.register(&opts).await { tracing::error!("bot {bot_id} registration failed: {e}"); let _ = db::set_bot_status(&db, &bot_id, "error", Some(&e.to_string())).await; continue; } backoff = 1; attempt = 0; let (write_tx, mut write_rx) = mpsc::channel::(256); let (timer_tx, mut timer_rx) = mpsc::channel::(64); let lua = match create_sandbox() { Ok(l) => l, Err(e) => { tracing::error!("bot {bot_id} lua sandbox creation failed: {e}"); let _ = db::set_bot_status(&db, &bot_id, "error", Some(&e.to_string())).await; return; } }; let irc_api = IrcApi { write_tx: write_tx.clone() }; let kv_api = KvApi { db: db.clone(), bot_id: bot_id.clone() }; let log_api = LogApi { db: db.clone(), bot_id: bot_id.clone(), log_tx: log_tx.clone() }; let http_api = HttpApi { client: http_client.clone(), rate_limiter: Arc::new(Mutex::new(HttpRateLimiter::new(10))), }; let timer_api = TimerApi { timer_tx: timer_tx.clone() }; if let Err(e) = crate::lua::api::register_apis(&lua, irc_api, kv_api, log_api, http_api, timer_api) { tracing::error!("bot {bot_id} api registration failed: {e}"); return; } for script in &scripts { if let Err(e) = lua.load(&script.source).set_name(&script.name).exec() { tracing::warn!("bot {bot_id} script {} load error: {e}", script.name); let _ = crate::db::insert_log(&db, &bot_id, "error", &format!("script {} error: {e}", script.name)).await; metrics::counter!("irc_now_bot_script_load_errors_total").increment(1); } } let _ = db::set_bot_status(&db, &bot_id, "running", None).await; let _ = crate::db::insert_log(&db, &bot_id, "info", "connected").await; let entry = LogEntry { bot_id: bot_id.clone(), level: "info".to_string(), message: "connected".to_string(), timestamp: time::OffsetDateTime::now_utc().to_string(), }; let _ = log_tx.send(entry); let (mut reader, mut writer) = conn.into_parts(); let write_task = tokio::spawn(async move { while let Some(msg) = write_rx.recv().await { let data = format!("{}\r\n", msg); if writer.write_all(data.as_bytes()).await.is_err() { break; } } }); const MAX_IRC_LINE: usize = 8192; let disconnected = loop { let mut line = String::new(); tokio::select! { result = reader.read_line(&mut line) => { match result { Ok(0) | Err(_) => break true, Ok(_) if line.len() > MAX_IRC_LINE => { tracing::warn!("bot {bot_id} received oversize IRC line ({} bytes), disconnecting", line.len()); break true; } Ok(_) => { if let Some(msg) = IrcMessage::parse(&line) { if msg.command == "PING" { let token = msg.params.first().map(|s| s.as_str()).unwrap_or(""); let _ = write_tx.send(IrcMessage::pong(token)).await; } else if let Some(event) = IrcEvent::from_message(&msg) { if let Err(e) = dispatch_event(&lua, &event) { tracing::warn!("bot {bot_id} lua error: {e}"); let _ = crate::db::insert_log( &db, &bot_id, "error", &format!("lua error in {}: {e}", event.handler_name()), ).await; } } } line.clear(); } } } Some(timer_event) = timer_rx.recv() => { let func: mlua::Function = match lua.registry_value(&timer_event.registry_key) { Ok(f) => f, Err(_) => continue, }; let _ = reset_instruction_counter(&lua); if let Err(e) = func.call::<()>(()) { tracing::warn!("bot {bot_id} timer error: {e}"); let _ = crate::db::insert_log(&db, &bot_id, "error", &format!("timer error: {e}")).await; } if timer_event.repeating { let tx = timer_tx.clone(); let secs = timer_event.interval_secs; let key = lua.create_registry_value(func).ok(); if let Some(key) = key { tokio::spawn(async move { tokio::time::sleep(std::time::Duration::from_secs_f64(secs)).await; let _ = tx.send(TimerEvent { registry_key: key, repeating: true, interval_secs: secs, }).await; }); } } else { lua.remove_registry_value(timer_event.registry_key).ok(); } } Some((target, text)) = msg_rx.recv() => { let _ = write_tx.send(IrcMessage::privmsg(&target, &text)).await; } _ = stop_rx.recv() => { break false; } } }; write_task.abort(); if !disconnected { let _ = db::set_bot_status(&db, &bot_id, "stopped", None).await; let _ = crate::db::insert_log(&db, &bot_id, "info", "stopped").await; return; } let _ = crate::db::insert_log(&db, &bot_id, "warn", "disconnected, will reconnect").await; } }); { let mut bots = self.bots.write().await; bots.insert(bot_id.to_string(), BotHandle { stop_tx, msg_tx, task }); } Ok(()) } pub async fn send_message(&self, bot_id: &str, target: &str, text: &str) -> Result<(), String> { let bots = self.bots.read().await; let handle = bots.get(bot_id).ok_or("bot not running")?; handle .msg_tx .send((target.to_string(), text.to_string())) .await .map_err(|_| "send failed".to_string()) } pub async fn stop_bot(&self, bot_id: &str) -> Result<(), String> { let handle = { let mut bots = self.bots.write().await; bots.remove(bot_id) }; match handle { Some(h) => { let _ = h.stop_tx.send(()).await; h.task.abort(); db::set_bot_status(&self.db, bot_id, "stopped", None) .await .ok(); Ok(()) } None => Err("bot is not running".to_string()), } } pub async fn start_all_enabled(self: Arc) { let bots = match db::list_enabled_bots(&self.db).await { Ok(b) => b, Err(e) => { tracing::error!("failed to list enabled bots: {e}"); return; } }; for bot in bots { let jitter = rand::random::() % 30_000; let bot_id = bot.id.clone(); let mgr = self.clone(); tokio::spawn(async move { tokio::time::sleep(std::time::Duration::from_millis(jitter)).await; if let Err(e) = mgr.start_bot(&bot_id).await { tracing::error!("failed to auto-start bot {bot_id}: {e}"); } }); } } }