use crate::security::command_guard::SafeCommand; use crate::shared::state::AppState; use chrono::{DateTime, Utc}; use diesel::prelude::*; use log::info; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpServer { pub id: String, pub name: String, pub description: String, pub server_type: McpServerType, pub connection: McpConnection, pub auth: McpAuth, pub tools: Vec, pub capabilities: McpCapabilities, pub status: McpServerStatus, pub bot_id: String, pub created_at: DateTime, pub updated_at: DateTime, pub last_health_check: Option>, pub health_status: HealthStatus, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum McpServerType { Database, Filesystem, Web, Email, Slack, Teams, Analytics, Search, Storage, Compute, Custom(String), } impl Default for McpServerType { fn default() -> Self { Self::Custom("unknown".to_string()) } } impl From<&str> for McpServerType { fn from(s: &str) -> Self { match s.to_lowercase().as_str() { "database" | "db" => Self::Database, "filesystem" | "fs" | "file" => Self::Filesystem, "web" | "http" | "rest" | "api" => Self::Web, "email" | "mail" | "smtp" | "imap" => Self::Email, "slack" => Self::Slack, "teams" | "microsoft-teams" => Self::Teams, "analytics" | "data" => Self::Analytics, "search" | "elasticsearch" | "opensearch" => Self::Search, "storage" | "s3" | "blob" | "gcs" => Self::Storage, "compute" | "lambda" | "function" => Self::Compute, other => Self::Custom(other.to_string()), } } } impl std::fmt::Display for McpServerType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Database => write!(f, "database"), Self::Filesystem => write!(f, "filesystem"), Self::Web => write!(f, "web"), Self::Email => write!(f, "email"), Self::Slack => write!(f, "slack"), Self::Teams => write!(f, "teams"), Self::Analytics => write!(f, "analytics"), Self::Search => write!(f, "search"), Self::Storage => write!(f, "storage"), Self::Compute => write!(f, "compute"), Self::Custom(s) => write!(f, "{s}"), } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpConnection { pub connection_type: ConnectionType, pub url: String, pub port: Option, pub timeout_seconds: i32, pub max_retries: i32, pub retry_backoff_ms: i32, pub keep_alive: bool, pub tls_config: Option, } impl Default for McpConnection { fn default() -> Self { Self { connection_type: ConnectionType::Http, url: "http://localhost:8080".to_string(), port: None, timeout_seconds: 30, max_retries: 3, retry_backoff_ms: 1000, keep_alive: true, tls_config: None, } } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Default)] pub enum ConnectionType { #[default] Http, WebSocket, Grpc, UnixSocket, Stdio, Tcp, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TlsConfig { pub enabled: bool, pub verify_certificates: bool, pub ca_cert_path: Option, pub client_cert_path: Option, pub client_key_path: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpAuth { pub auth_type: McpAuthType, pub credentials: McpCredentials, } impl Default for McpAuth { fn default() -> Self { Self { auth_type: McpAuthType::None, credentials: McpCredentials::None, } } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Default)] pub enum McpAuthType { #[default] None, ApiKey, Bearer, Basic, OAuth2, Certificate, Custom(String), } #[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Default)] pub enum McpCredentials { #[default] None, ApiKey { header_name: String, key_ref: String, }, Bearer { token_ref: String, }, Basic { username_ref: String, password_ref: String, }, OAuth2 { client_id_ref: String, client_secret_ref: String, token_url: String, scopes: Vec, }, Certificate { cert_ref: String, key_ref: String, }, Custom(HashMap), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpTool { pub name: String, pub description: String, pub input_schema: serde_json::Value, pub output_schema: Option, pub required_permissions: Vec, pub risk_level: ToolRiskLevel, pub is_destructive: bool, pub requires_approval: bool, pub rate_limit: Option, pub timeout_seconds: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Default)] pub enum ToolRiskLevel { Safe, #[default] Low, Medium, High, Critical, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct McpCapabilities { pub tools: bool, pub resources: bool, pub prompts: bool, pub logging: bool, pub streaming: bool, pub cancellation: bool, pub custom: HashMap, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Default)] pub enum McpServerStatus { Active, #[default] Inactive, Connecting, Error(String), Maintenance, Unknown, } #[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Default)] pub struct HealthStatus { pub healthy: bool, pub last_check: Option>, pub response_time_ms: Option, pub error_message: Option, pub consecutive_failures: i32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpRequest { pub id: String, pub server: String, pub tool: String, pub arguments: serde_json::Value, pub context: McpRequestContext, pub timeout_seconds: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpRequestContext { pub session_id: String, pub bot_id: String, pub user_id: String, pub task_id: Option, pub step_id: Option, pub correlation_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpResponse { pub id: String, pub success: bool, pub result: Option, pub error: Option, pub metadata: McpResponseMetadata, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpError { pub code: String, pub message: String, pub details: Option, pub retryable: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpResponseMetadata { pub duration_ms: i64, pub server_version: Option, pub rate_limit_remaining: Option, pub rate_limit_reset: Option>, } pub struct McpClient { state: Arc, config: McpClientConfig, servers: HashMap, http_client: reqwest::Client, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpClientConfig { pub enabled: bool, pub default_timeout_seconds: i32, pub max_concurrent_requests: i32, pub cache_enabled: bool, pub cache_ttl_seconds: i32, pub audit_enabled: bool, pub health_check_interval_seconds: i32, pub auto_retry: bool, pub circuit_breaker_threshold: i32, } impl Default for McpClientConfig { fn default() -> Self { Self { enabled: true, default_timeout_seconds: 30, max_concurrent_requests: 10, cache_enabled: true, cache_ttl_seconds: 300, audit_enabled: true, health_check_interval_seconds: 60, auto_retry: true, circuit_breaker_threshold: 5, } } } impl std::fmt::Debug for McpClient { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("McpClient") .field("config", &self.config) .field("servers_count", &self.servers.len()) .finish_non_exhaustive() } } impl McpClient { pub fn new(state: Arc) -> Self { let http_client = reqwest::Client::builder() .timeout(Duration::from_secs(30)) .build() .unwrap_or_default(); Self { state, config: McpClientConfig::default(), servers: HashMap::new(), http_client, } } pub fn with_config(state: Arc, config: McpClientConfig) -> Self { let http_client = reqwest::Client::builder() .timeout(Duration::from_secs(config.default_timeout_seconds as u64)) .build() .unwrap_or_default(); Self { state, config, servers: HashMap::new(), http_client, } } pub fn load_servers( &mut self, bot_id: &Uuid, ) -> Result<(), Box> { let mut conn = self .state .conn .get() .map_err(|e| format!("DB error: {}", e))?; let bot_id_str = bot_id.to_string(); let query = diesel::sql_query( "SELECT id, name, description, server_type, config, status, created_at, updated_at FROM mcp_servers WHERE bot_id = $1 AND status != 'deleted'", ) .bind::(&bot_id_str); #[derive(QueryableByName)] struct ServerRow { #[diesel(sql_type = diesel::sql_types::Text)] id: String, #[diesel(sql_type = diesel::sql_types::Text)] name: String, #[diesel(sql_type = diesel::sql_types::Nullable)] description: Option, #[diesel(sql_type = diesel::sql_types::Text)] server_type: String, #[diesel(sql_type = diesel::sql_types::Text)] config: String, #[diesel(sql_type = diesel::sql_types::Text)] status: String, } let rows: Vec = query.load(&mut *conn).unwrap_or_default(); for row in rows { let server = McpServer { id: row.id.clone(), name: row.name.clone(), description: row.description.unwrap_or_default(), server_type: McpServerType::from(row.server_type.as_str()), connection: serde_json::from_str(&row.config).unwrap_or_default(), auth: McpAuth::default(), tools: Vec::new(), capabilities: McpCapabilities::default(), status: match row.status.as_str() { "active" => McpServerStatus::Active, "maintenance" => McpServerStatus::Maintenance, "error" => McpServerStatus::Error("Unknown error".to_string()), _ => McpServerStatus::Inactive, }, bot_id: bot_id_str.clone(), created_at: Utc::now(), updated_at: Utc::now(), last_health_check: None, health_status: HealthStatus::default(), }; self.servers.insert(row.name, server); } info!( "Loaded {} MCP servers for bot {}", self.servers.len(), bot_id ); Ok(()) } pub fn register_server( &mut self, server: McpServer, ) -> Result<(), Box> { let mut conn = self .state .conn .get() .map_err(|e| format!("DB error: {}", e))?; let config_json = serde_json::to_string(&server.connection)?; let now = Utc::now().to_rfc3339(); let server_type_str = server.server_type.to_string(); let query = diesel::sql_query( "INSERT INTO mcp_servers (id, bot_id, name, description, server_type, config, status, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (bot_id, name) DO UPDATE SET description = EXCLUDED.description, server_type = EXCLUDED.server_type, config = EXCLUDED.config, status = EXCLUDED.status, updated_at = EXCLUDED.updated_at" ) .bind::(&server.id) .bind::(&server.bot_id) .bind::(&server.name) .bind::(&server.description) .bind::(&server_type_str) .bind::(&config_json) .bind::("active") .bind::(&now) .bind::(&now); query .execute(&mut *conn) .map_err(|e| format!("Failed to register MCP server: {}", e))?; self.servers.insert(server.name.clone(), server); Ok(()) } pub fn get_server(&self, name: &str) -> Option<&McpServer> { self.servers.get(name) } pub fn list_servers(&self) -> Vec<&McpServer> { self.servers.values().collect() } pub async fn list_tools( &self, server_name: &str, ) -> Result, Box> { let server = self .servers .get(server_name) .ok_or_else(|| format!("MCP server '{}' not found", server_name))?; if server.connection.connection_type == ConnectionType::Http { let url = format!("{}/tools/list", server.connection.url); let response = self .http_client .get(&url) .timeout(Duration::from_secs( server.connection.timeout_seconds as u64, )) .send() .await?; if response.status().is_success() { let tools: Vec = response.json().await?; return Ok(tools); } } Ok(server.tools.clone()) } pub async fn invoke_tool( &self, request: McpRequest, ) -> Result> { let start_time = std::time::Instant::now(); let server = self .servers .get(&request.server) .ok_or_else(|| format!("MCP server '{}' not found", request.server))?; if server.status != McpServerStatus::Active { return Ok(McpResponse { id: request.id, success: false, result: None, error: Some(McpError { code: "SERVER_UNAVAILABLE".to_string(), message: format!( "MCP server '{}' is not active (status: {:?})", request.server, server.status ), details: None, retryable: true, }), metadata: McpResponseMetadata { duration_ms: start_time.elapsed().as_millis() as i64, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }); } if self.config.audit_enabled { info!( "MCP request: server={} tool={}", request.server, request.tool ); } let result = match server.connection.connection_type { ConnectionType::Http => self.invoke_http(server, &request).await, ConnectionType::Stdio => self.invoke_stdio(server, &request).await, _ => Err(format!( "Connection type {:?} not yet supported", server.connection.connection_type ) .into()), }; let duration_ms = start_time.elapsed().as_millis() as i64; match result { Ok(mut response) => { response.metadata.duration_ms = duration_ms; if self.config.audit_enabled { info!( "MCP response: id={} success={}", response.id, response.success ); } Ok(response) } Err(e) => { let response = McpResponse { id: request.id.clone(), success: false, result: None, error: Some(McpError { code: "INVOCATION_ERROR".to_string(), message: e.to_string(), details: None, retryable: true, }), metadata: McpResponseMetadata { duration_ms, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }; if self.config.audit_enabled { info!( "MCP error response: id={} error={:?}", response.id, response.error ); } Ok(response) } } } async fn invoke_http( &self, server: &McpServer, request: &McpRequest, ) -> Result> { let url = format!("{}/tools/call", server.connection.url); let body = serde_json::json!({ "name": request.tool, "arguments": request.arguments }); let timeout = request .timeout_seconds .unwrap_or(server.connection.timeout_seconds); let mut http_request = self .http_client .post(&url) .json(&body) .timeout(Duration::from_secs(timeout as u64)); http_request = Self::add_auth_headers(http_request, &server.auth); let response = http_request.send().await?; let status = response.status(); if status.is_success() { let result: serde_json::Value = response.json().await?; Ok(McpResponse { id: request.id.clone(), success: true, result: Some(result), error: None, metadata: McpResponseMetadata { duration_ms: 0, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }) } else { let error_text = response.text().await.unwrap_or_default(); Ok(McpResponse { id: request.id.clone(), success: false, result: None, error: Some(McpError { code: format!("HTTP_{}", status.as_u16()), message: error_text, details: None, retryable: status.as_u16() >= 500, }), metadata: McpResponseMetadata { duration_ms: 0, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }) } } async fn invoke_stdio( &self, server: &McpServer, request: &McpRequest, ) -> Result> { let _input = serde_json::json!({ "jsonrpc": "2.0", "method": "tools/call", "params": { "name": request.tool, "arguments": request.arguments }, "id": request.id }); let cmd = SafeCommand::new(&server.connection.url) .map_err(|e| anyhow::anyhow!("Failed to build MCP command: {}", e))?; let output = cmd.execute_async() .await .map_err(|e| anyhow::anyhow!("Failed to execute MCP command: {}", e))?; if output.status.success() { let result: serde_json::Value = serde_json::from_slice(&output.stdout)?; Ok(McpResponse { id: request.id.clone(), success: true, result: result.get("result").cloned(), error: None, metadata: McpResponseMetadata { duration_ms: 0, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }) } else { let stderr = String::from_utf8_lossy(&output.stderr); Ok(McpResponse { id: request.id.clone(), success: false, result: None, error: Some(McpError { code: "STDIO_ERROR".to_string(), message: stderr.to_string(), details: None, retryable: false, }), metadata: McpResponseMetadata { duration_ms: 0, server_version: None, rate_limit_remaining: None, rate_limit_reset: None, }, }) } } fn add_auth_headers( mut request: reqwest::RequestBuilder, auth: &McpAuth, ) -> reqwest::RequestBuilder { match &auth.credentials { McpCredentials::ApiKey { header_name, key_ref, } => { request = request.header(header_name.as_str(), key_ref.as_str()); } McpCredentials::Bearer { token_ref } => { request = request.bearer_auth(token_ref); } McpCredentials::Basic { username_ref, password_ref, } => { request = request.basic_auth(username_ref, Some(password_ref)); } _ => {} } request } pub async fn health_check( &mut self, server_name: &str, ) -> Result> { let server = self .servers .get_mut(server_name) .ok_or_else(|| format!("MCP server '{}' not found", server_name))?; let start_time = std::time::Instant::now(); let health_url = format!("{}/health", server.connection.url); let result = self .http_client .get(&health_url) .timeout(Duration::from_secs(5)) .send() .await; let latency_ms = start_time.elapsed().as_millis() as i64; match result { Ok(response) => { if response.status().is_success() { server.status = McpServerStatus::Active; Ok(HealthStatus { healthy: true, last_check: Some(Utc::now()), response_time_ms: Some(latency_ms), error_message: None, consecutive_failures: 0, }) } else { server.status = McpServerStatus::Error(format!("HTTP {}", response.status())); Ok(HealthStatus { healthy: false, last_check: Some(Utc::now()), response_time_ms: Some(latency_ms), error_message: Some(format!( "Server returned status {}", response.status() )), consecutive_failures: 1, }) } } Err(e) => { server.status = McpServerStatus::Unknown; Ok(HealthStatus { healthy: false, last_check: Some(Utc::now()), response_time_ms: Some(latency_ms), error_message: Some(format!("Health check failed: {}", e)), consecutive_failures: 1, }) } } } }