use crate::auto_task::TaskManifest; use crate::core::bot::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}; use crate::core::bot_database::BotDatabaseManager; use crate::core::config::AppConfig; #[cfg(any(feature = "research", feature = "llm"))] use crate::core::kb::KnowledgeBaseManager; use crate::core::session::SessionManager; use crate::core::shared::analytics::MetricsCollector; #[cfg(all(test, feature = "directory"))] use crate::core::shared::test_utils::create_mock_auth_service; #[cfg(all(test, feature = "llm"))] use crate::core::shared::test_utils::MockLLMProvider; #[cfg(feature = "directory")] use crate::directory::AuthService; #[cfg(feature = "compliance")] use crate::legal::LegalService; #[cfg(feature = "llm")] use crate::llm::{DynamicLLMProvider, LLMProvider}; #[cfg(feature = "project")] use crate::project::ProjectService; use crate::security::auth_provider::AuthProviderRegistry; use crate::security::jwt::JwtManager; use crate::security::rbac_middleware::RbacManager; use crate::core::shared::models::BotResponse; use crate::core::shared::utils::DbPool; #[cfg(feature = "tasks")] use crate::tasks::{TaskEngine, TaskScheduler}; #[cfg(feature = "drive")] use aws_sdk_s3::Client as S3Client; #[cfg(test)] use diesel::r2d2::{ConnectionManager, Pool}; #[cfg(test)] use diesel::PgConnection; #[cfg(feature = "cache")] use redis::Client as RedisClient; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{broadcast, mpsc, RwLock}; #[cfg(not(feature = "drive"))] #[derive(Debug, Clone, Default)] pub struct NoDrive; #[cfg(not(feature = "drive"))] impl NoDrive { pub fn new() -> Self { NoDrive } } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct AttendantNotification { #[serde(rename = "type")] pub notification_type: String, pub session_id: String, pub user_id: String, pub user_name: Option, pub user_phone: Option, pub channel: String, pub content: String, pub timestamp: String, pub assigned_to: Option, pub priority: i32, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct AgentActivity { pub phase: String, pub items_processed: u32, #[serde(skip_serializing_if = "Option::is_none")] pub items_total: Option, #[serde(skip_serializing_if = "Option::is_none")] pub speed_per_min: Option, #[serde(skip_serializing_if = "Option::is_none")] pub eta_seconds: Option, #[serde(skip_serializing_if = "Option::is_none")] pub current_item: Option, #[serde(skip_serializing_if = "Option::is_none")] pub bytes_processed: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tokens_used: Option, #[serde(skip_serializing_if = "Option::is_none")] pub files_created: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tables_created: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub log_lines: Option>, } impl AgentActivity { pub fn new(phase: impl Into) -> Self { Self { phase: phase.into(), items_processed: 0, items_total: None, speed_per_min: None, eta_seconds: None, current_item: None, bytes_processed: None, tokens_used: None, files_created: None, tables_created: None, log_lines: None, } } #[must_use] pub fn with_progress(mut self, processed: u32, total: Option) -> Self { self.items_processed = processed; self.items_total = total; self } #[must_use] pub fn with_speed(mut self, speed: f32, eta_seconds: Option) -> Self { self.speed_per_min = Some(speed); self.eta_seconds = eta_seconds; self } #[must_use] pub fn with_current_item(mut self, item: impl Into) -> Self { self.current_item = Some(item.into()); self } #[must_use] pub fn with_bytes(mut self, bytes: u64) -> Self { self.bytes_processed = Some(bytes); self } #[must_use] pub fn with_tokens(mut self, tokens: u32) -> Self { self.tokens_used = Some(tokens); self } #[must_use] pub fn with_files(mut self, files: Vec) -> Self { self.files_created = Some(files); self } #[must_use] pub fn with_tables(mut self, tables: Vec) -> Self { self.tables_created = Some(tables); self } #[must_use] pub fn with_log_lines(mut self, lines: Vec) -> Self { self.log_lines = Some(lines); self } #[must_use] pub fn add_log_line(mut self, line: impl Into) -> Self { let lines = self.log_lines.get_or_insert_with(Vec::new); lines.push(line.into()); self } } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct TaskProgressEvent { #[serde(rename = "type")] pub event_type: String, pub task_id: String, pub step: String, pub message: String, pub progress: u8, pub total_steps: u8, pub current_step: u8, pub timestamp: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, #[serde(skip_serializing_if = "Option::is_none")] pub activity: Option, #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, } impl TaskProgressEvent { pub fn new( task_id: impl Into, step: impl Into, message: impl Into, ) -> Self { Self { event_type: "task_progress".to_string(), task_id: task_id.into(), step: step.into(), message: message.into(), progress: 0, total_steps: 0, current_step: 0, timestamp: chrono::Utc::now().to_rfc3339(), details: None, error: None, activity: None, text: None, } } pub fn llm_stream(task_id: impl Into, text: impl Into) -> Self { Self { event_type: "llm_stream".to_string(), task_id: task_id.into(), step: "llm_stream".to_string(), message: String::new(), progress: 0, total_steps: 0, current_step: 0, timestamp: chrono::Utc::now().to_rfc3339(), details: None, error: None, activity: None, text: Some(text.into()), } } #[must_use] pub fn with_progress(mut self, current: u8, total: u8) -> Self { self.current_step = current; self.total_steps = total; self.progress = if total > 0 { ((current as u16 * 100) / total as u16) as u8 } else { 0 }; self } #[must_use] pub fn with_details(mut self, details: impl Into) -> Self { self.details = Some(details.into()); self } #[must_use] pub fn with_activity(mut self, activity: AgentActivity) -> Self { self.activity = Some(activity); self } #[must_use] pub fn with_event_type(mut self, event_type: impl Into) -> Self { self.event_type = event_type.into(); self } #[must_use] pub fn with_error(mut self, error: impl Into) -> Self { self.event_type = "task_error".to_string(); self.error = Some(error.into()); self } #[must_use] pub fn completed(mut self) -> Self { self.event_type = "task_completed".to_string(); self.progress = 100; self } pub fn started( task_id: impl Into, message: impl Into, total_steps: u8, ) -> Self { Self { event_type: "task_started".to_string(), task_id: task_id.into(), step: "init".to_string(), message: message.into(), progress: 0, total_steps, current_step: 0, timestamp: chrono::Utc::now().to_rfc3339(), details: None, error: None, activity: None, text: None, } } } #[derive(Clone, Default)] pub struct Extensions { map: Arc>>>, } impl Extensions { #[must_use] pub fn new() -> Self { Self { map: Arc::new(RwLock::new(HashMap::new())), } } pub async fn insert(&self, value: T) { let mut map = self.map.write().await; map.insert(TypeId::of::(), Arc::new(value)); } pub fn insert_blocking(&self, value: T) { let map = self.map.clone(); let (tx, rx) = std::sync::mpsc::channel(); std::thread::spawn(move || { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build(); if let Ok(rt) = rt { rt.block_on(async { let mut guard = map.write().await; guard.insert(TypeId::of::(), Arc::new(value)); }); } let _ = tx.send(()); }); let _ = rx.recv(); } pub async fn get(&self) -> Option> { let map = self.map.read().await; map.get(&TypeId::of::()) .and_then(|boxed| Arc::clone(boxed).downcast::().ok()) } pub async fn contains(&self) -> bool { let map = self.map.read().await; map.contains_key(&TypeId::of::()) } pub async fn remove(&self) -> Option> { let mut map = self.map.write().await; map.remove(&TypeId::of::()) .and_then(|boxed| boxed.downcast::().ok()) } pub async fn len(&self) -> usize { let map = self.map.read().await; map.len() } pub async fn is_empty(&self) -> bool { let map = self.map.read().await; map.is_empty() } } impl std::fmt::Debug for Extensions { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Extensions").finish_non_exhaustive() } } /// Billing alert notification for WebSocket broadcast #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct BillingAlertNotification { pub alert_id: uuid::Uuid, pub organization_id: uuid::Uuid, pub severity: String, pub alert_type: String, pub title: String, pub message: String, pub metric: String, pub percentage: f64, pub triggered_at: chrono::DateTime, } pub struct AppState { #[cfg(feature = "drive")] pub drive: Option, #[cfg(not(feature = "drive"))] #[allow(non_snake_case)] pub drive: Option, #[cfg(feature = "cache")] pub cache: Option>, pub bucket_name: String, pub config: Option, pub conn: DbPool, pub database_url: String, pub bot_database_manager: Arc, pub session_manager: Arc>, pub metrics_collector: MetricsCollector, #[cfg(feature = "tasks")] pub task_scheduler: Option>, #[cfg(feature = "llm")] pub llm_provider: Arc, #[cfg(feature = "llm")] pub dynamic_llm_provider: Option>, #[cfg(feature = "directory")] pub auth_service: Arc>, pub channels: Arc>>>, pub response_channels: Arc>>>, /// Active streaming sessions for cancellation: session_id → cancellation sender pub active_streams: Arc>>>, /// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver. pub hear_channels: Arc>>>, pub web_adapter: Arc, pub voice_adapter: Arc, #[cfg(any(feature = "research", feature = "llm"))] pub kb_manager: Option>, #[cfg(feature = "tasks")] pub task_engine: Arc, pub extensions: Extensions, pub attendant_broadcast: Option>, pub task_progress_broadcast: Option>, pub billing_alert_broadcast: Option>, pub task_manifests: Arc>>, #[cfg(feature = "terminal")] pub terminal_manager: Arc, #[cfg(feature = "project")] pub project_service: Arc>, #[cfg(feature = "compliance")] pub legal_service: Arc>, pub jwt_manager: Option>, pub auth_provider_registry: Option>, pub rbac_manager: Option>, } impl Clone for AppState { fn clone(&self) -> Self { Self { #[cfg(feature = "drive")] drive: self.drive.clone(), #[cfg(not(feature = "drive"))] drive: None, bucket_name: self.bucket_name.clone(), config: self.config.clone(), conn: self.conn.clone(), database_url: self.database_url.clone(), bot_database_manager: Arc::clone(&self.bot_database_manager), #[cfg(feature = "cache")] cache: self.cache.clone(), session_manager: Arc::clone(&self.session_manager), metrics_collector: self.metrics_collector.clone(), #[cfg(feature = "tasks")] task_scheduler: self.task_scheduler.clone(), #[cfg(feature = "llm")] llm_provider: Arc::clone(&self.llm_provider), #[cfg(feature = "llm")] dynamic_llm_provider: self.dynamic_llm_provider.clone(), #[cfg(feature = "directory")] auth_service: Arc::clone(&self.auth_service), #[cfg(any(feature = "research", feature = "llm"))] kb_manager: self.kb_manager.clone(), channels: Arc::clone(&self.channels), response_channels: Arc::clone(&self.response_channels), active_streams: Arc::clone(&self.active_streams), hear_channels: Arc::clone(&self.hear_channels), web_adapter: Arc::clone(&self.web_adapter), voice_adapter: Arc::clone(&self.voice_adapter), #[cfg(feature = "tasks")] task_engine: Arc::clone(&self.task_engine), extensions: self.extensions.clone(), attendant_broadcast: self.attendant_broadcast.clone(), task_progress_broadcast: self.task_progress_broadcast.clone(), billing_alert_broadcast: self.billing_alert_broadcast.clone(), task_manifests: Arc::clone(&self.task_manifests), #[cfg(feature = "terminal")] terminal_manager: Arc::clone(&self.terminal_manager), #[cfg(feature = "project")] project_service: Arc::clone(&self.project_service), #[cfg(feature = "compliance")] legal_service: Arc::clone(&self.legal_service), jwt_manager: self.jwt_manager.clone(), auth_provider_registry: self.auth_provider_registry.clone(), rbac_manager: self.rbac_manager.clone(), } } } impl std::fmt::Debug for AppState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut debug = f.debug_struct("AppState"); #[cfg(feature = "drive")] debug.field("drive", &self.drive.is_some()); #[cfg(feature = "cache")] debug.field("cache", &self.cache.is_some()); debug .field("bucket_name", &self.bucket_name) .field("config", &self.config.is_some()) .field("conn", &"DbPool") .field("database_url", &"[REDACTED]") .field("bot_database_manager", &"Arc") .field("session_manager", &"Arc>") .field("metrics_collector", &"MetricsCollector"); #[cfg(any(feature = "research", feature = "llm"))] debug.field("kb_manager", &self.kb_manager.is_some()); #[cfg(feature = "tasks")] debug.field("task_scheduler", &self.task_scheduler.is_some()); #[cfg(feature = "llm")] debug.field("llm_provider", &"Arc"); #[cfg(feature = "directory")] debug.field("auth_service", &"Arc>"); debug .field("channels", &"Arc>") .field("response_channels", &"Arc>") .field("web_adapter", &self.web_adapter) .field("voice_adapter", &self.voice_adapter); #[cfg(any(feature = "research", feature = "llm"))] debug.field("kb_manager", &self.kb_manager.is_some()); #[cfg(feature = "tasks")] debug.field("task_engine", &"Arc"); debug .field("extensions", &self.extensions) .field("attendant_broadcast", &self.attendant_broadcast.is_some()) .field( "task_progress_broadcast", &self.task_progress_broadcast.is_some(), ) .field("jwt_manager", &self.jwt_manager.is_some()) .field( "auth_provider_registry", &self.auth_provider_registry.is_some(), ) .field("rbac_manager", &self.rbac_manager.is_some()) .finish() } } impl AppState { pub fn broadcast_task_progress(&self, event: TaskProgressEvent) { log::info!( "Broadcasting: task_id={}, step={}, message={}", event.task_id, event.step, event.message ); if let Some(tx) = &self.task_progress_broadcast { let receiver_count = tx.receiver_count(); log::info!( "Broadcast channel has {} receivers", receiver_count ); match tx.send(event) { Ok(_) => { log::info!("Event sent successfully"); } Err(e) => { log::warn!("No listeners for task progress: {e}"); } } } else { log::warn!("No broadcast channel configured!"); } } pub fn emit_progress(&self, task_id: &str, step: &str, message: &str, current: u8, total: u8) { let event = TaskProgressEvent::new(task_id, step, message).with_progress(current, total); self.broadcast_task_progress(event); } pub fn emit_progress_with_details( &self, task_id: &str, step: &str, message: &str, current: u8, total: u8, details: &str, ) { let event = TaskProgressEvent::new(task_id, step, message) .with_progress(current, total) .with_details(details); self.broadcast_task_progress(event); } pub fn emit_activity( &self, task_id: &str, step: &str, message: &str, current: u8, total: u8, activity: AgentActivity, ) { let event = TaskProgressEvent::new(task_id, step, message) .with_progress(current, total) .with_activity(activity); self.broadcast_task_progress(event); } pub fn emit_task_started(&self, task_id: &str, message: &str, total_steps: u8) { let event = TaskProgressEvent::started(task_id, message, total_steps); self.broadcast_task_progress(event); } pub fn emit_task_completed(&self, task_id: &str, message: &str) { let event = TaskProgressEvent::new(task_id, "complete", message).completed(); self.broadcast_task_progress(event); } pub fn emit_task_error(&self, task_id: &str, step: &str, error: &str) { let event = TaskProgressEvent::new(task_id, step, "Task failed").with_error(error); self.broadcast_task_progress(event); } pub fn emit_llm_stream(&self, task_id: &str, text: &str) { let event = TaskProgressEvent::llm_stream(task_id, text); if let Some(tx) = &self.task_progress_broadcast { // Don't log every stream chunk - too noisy let _ = tx.send(event); } } } #[cfg(test)] impl Default for AppState { fn default() -> Self { let database_url = crate::core::shared::utils::get_database_url_sync() .expect("AppState::default() requires Vault to be configured"); let manager = ConnectionManager::::new(&database_url); let pool = Pool::builder() .max_size(1) .test_on_check_out(false) .connection_timeout(std::time::Duration::from_secs(5)) .build(manager) .expect("Failed to create test database pool"); let conn = pool.get().expect("Failed to get test database connection"); let session_manager = SessionManager::new(conn, None); let (attendant_tx, _) = broadcast::channel(100); let (task_progress_tx, _) = broadcast::channel(100); let bot_database_manager = Arc::new(BotDatabaseManager::new(pool.clone(), &database_url)); Self { #[cfg(feature = "drive")] drive: None, #[cfg(not(feature = "drive"))] drive: None, #[cfg(feature = "cache")] cache: None, bucket_name: "test-bucket".to_string(), config: None, conn: pool.clone(), database_url, bot_database_manager, session_manager: Arc::new(tokio::sync::Mutex::new(session_manager)), metrics_collector: MetricsCollector::new(), #[cfg(feature = "tasks")] task_scheduler: None, #[cfg(all(test, feature = "llm"))] llm_provider: Arc::new(MockLLMProvider::new()), #[cfg(feature = "llm")] dynamic_llm_provider: None, #[cfg(feature = "directory")] auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())), channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())), hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), web_adapter: Arc::new(WebChannelAdapter::new()), voice_adapter: Arc::new(VoiceAdapter::new()), #[cfg(any(feature = "research", feature = "llm"))] kb_manager: None, #[cfg(feature = "tasks")] task_engine: Arc::new(TaskEngine::new(pool)), extensions: Extensions::new(), attendant_broadcast: Some(attendant_tx), task_progress_broadcast: Some(task_progress_tx), billing_alert_broadcast: None, task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())), #[cfg(feature = "project")] project_service: Arc::new(RwLock::new(crate::project::ProjectService::new())), #[cfg(feature = "compliance")] legal_service: Arc::new(RwLock::new(crate::legal::LegalService::new())), jwt_manager: None, auth_provider_registry: None, rbac_manager: None, } } }