generalbots/src/billing/middleware.rs
Rodrigo Rodriguez (Pragmatismo) faeae250bc Add security protection module with sudo-based privilege escalation
- Create installer.rs for 'botserver install protection' command
- Requires root to install packages and create sudoers config
- Sudoers uses exact commands (no wildcards) for security
- Update all tool files (lynis, rkhunter, chkrootkit, suricata, lmd) to use sudo
- Update manager.rs service management to use sudo
- Add 'sudo' and 'visudo' to command_guard.rs whitelist
- Update CLI with install/remove/status protection commands

Security model:
- Installation requires root (sudo botserver install protection)
- Runtime uses sudoers NOPASSWD for specific commands only
- No wildcards in sudoers - exact command specifications
- Tools run on host system, not in containers
2026-01-10 09:41:12 -03:00

494 lines
14 KiB
Rust

use axum::{
body::Body,
extract::State,
http::{header::HeaderValue, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::billing::quotas::{QuotaAction, QuotaManager};
use crate::billing::UsageMetric;
/// Shared state handed to the quota middlewares in this module.
#[derive(Clone)]
pub struct QuotaMiddlewareState {
    /// Manager the middlewares call to check quotas and record usage.
    pub quota_manager: Arc<QuotaManager>,
    /// Runtime switch; while it holds `false` the middlewares forward requests untouched.
    pub enabled: Arc<RwLock<bool>>,
}

impl QuotaMiddlewareState {
    /// Wraps `quota_manager` in a new state with enforcement switched on.
    pub fn new(quota_manager: Arc<QuotaManager>) -> Self {
        let enabled = Arc::new(RwLock::new(true));
        Self {
            quota_manager,
            enabled,
        }
    }

    /// Switches enforcement on or off; the flag lives behind an `Arc`, so clones see the change.
    pub async fn set_enabled(&self, enabled: bool) {
        *self.enabled.write().await = enabled;
    }

    /// Returns the current value of the enforcement flag.
    pub async fn is_enabled(&self) -> bool {
        let flag = self.enabled.read().await;
        *flag
    }
}
pub async fn quota_middleware(
State(state): State<Arc<QuotaMiddlewareState>>,
request: Request<Body>,
next: Next,
) -> Response {
if !state.is_enabled().await {
return next.run(request).await;
}
let org_id = extract_organization_id(&request);
let org_id = match org_id {
Some(id) => id,
None => return next.run(request).await,
};
let metric = determine_metric_from_request(&request);
let action = state.quota_manager.check_action(org_id, metric).await;
match action {
QuotaAction::Allow => {
let response = next.run(request).await;
if response.status().is_success() {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, metric, 1)
.await
{
tracing::warn!("Failed to increment usage for org {}: {}", org_id, e);
}
}
response
}
QuotaAction::Warn { message, percentage } => {
let mut response = next.run(request).await;
if response.status().is_success() {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, metric, 1)
.await
{
tracing::warn!("Failed to increment usage for org {}: {}", org_id, e);
}
}
let headers = response.headers_mut();
headers.insert(
"X-Quota-Warning",
message.parse().unwrap_or_else(|_| HeaderValue::from_static("quota warning")),
);
headers.insert(
"X-Quota-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
}
QuotaAction::Block { message } => QuotaExceededResponse { message }.into_response(),
}
}
pub async fn api_rate_limit_middleware(
State(state): State<Arc<QuotaMiddlewareState>>,
request: Request<Body>,
next: Next,
) -> Response {
if !state.is_enabled().await {
return next.run(request).await;
}
let org_id = match extract_organization_id(&request) {
Some(id) => id,
None => return next.run(request).await,
};
let action = state
.quota_manager
.check_action(org_id, UsageMetric::ApiCalls)
.await;
match action {
QuotaAction::Allow => {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, UsageMetric::ApiCalls, 1)
.await
{
tracing::warn!("Failed to increment API call count for org {}: {}", org_id, e);
}
next.run(request).await
}
QuotaAction::Warn { message, percentage } => {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, UsageMetric::ApiCalls, 1)
.await
{
tracing::warn!("Failed to increment API call count for org {}: {}", org_id, e);
}
let mut response = next.run(request).await;
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Warning",
message.parse().unwrap_or_else(|_| HeaderValue::from_static("rate limit warning")),
);
headers.insert(
"X-RateLimit-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
}
QuotaAction::Block { message } => RateLimitExceededResponse { message }.into_response(),
}
}
pub async fn message_quota_middleware(
State(state): State<Arc<QuotaMiddlewareState>>,
request: Request<Body>,
next: Next,
) -> Response {
if !state.is_enabled().await {
return next.run(request).await;
}
let org_id = match extract_organization_id(&request) {
Some(id) => id,
None => return next.run(request).await,
};
let action = state
.quota_manager
.check_action(org_id, UsageMetric::Messages)
.await;
match action {
QuotaAction::Allow => {
let response = next.run(request).await;
if response.status().is_success() {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, UsageMetric::Messages, 1)
.await
{
tracing::warn!("Failed to increment message count for org {}: {}", org_id, e);
}
}
response
}
QuotaAction::Warn { message, percentage } => {
let response = next.run(request).await;
if response.status().is_success() {
if let Err(e) = state
.quota_manager
.increment_usage(org_id, UsageMetric::Messages, 1)
.await
{
tracing::warn!("Failed to increment message count for org {}: {}", org_id, e);
}
}
let mut response = response;
let headers = response.headers_mut();
headers.insert(
"X-Message-Quota-Warning",
message.parse().unwrap_or_else(|_| HeaderValue::from_static("message quota warning")),
);
headers.insert(
"X-Message-Quota-Usage-Percent",
percentage
.to_string()
.parse()
.unwrap_or_else(|_| HeaderValue::from_static("0")),
);
response
}
QuotaAction::Block { message } => MessageQuotaExceededResponse { message }.into_response(),
}
}
/// Storage quota middleware keyed on the request's `content-length` header.
///
/// The declared body size is checked against the `UsageMetric::StorageBytes`
/// quota before the request is forwarded, and added to the organization's
/// usage after a successful (2xx) downstream response.
///
/// The request is forwarded without any check when:
/// * enforcement is disabled,
/// * no organization id can be extracted,
/// * `content-length` is missing, unparsable or zero, or
/// * the quota lookup itself fails (logged at error level; fail-open).
///
/// When the quota would be exceeded a [`StorageQuotaExceededResponse`] is
/// returned and the downstream handler is not called.
pub async fn storage_check_middleware(
    State(state): State<Arc<QuotaMiddlewareState>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    if !state.is_enabled().await {
        return next.run(request).await;
    }
    let org_id = match extract_organization_id(&request) {
        Some(id) => id,
        None => return next.run(request).await,
    };

    // NOTE(review): bodies sent without a content-length header are treated
    // as size 0 and therefore skip the quota check entirely.
    let content_length = request
        .headers()
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(0);
    if content_length == 0 {
        return next.run(request).await;
    }

    let check_result = state
        .quota_manager
        .check_quota(org_id, UsageMetric::StorageBytes, content_length)
        .await;
    match check_result {
        Ok(result) => {
            if !result.allowed {
                // Print the limit as a plain number: formatting the Option
                // with `{:?}` leaked `Some(..)` / `None` into the
                // client-facing message.
                let message = match result.limit {
                    Some(limit) => format!(
                        "Storage quota exceeded. Current: {} bytes, Limit: {} bytes",
                        result.current, limit
                    ),
                    None => format!(
                        "Storage quota exceeded. Current: {} bytes",
                        result.current
                    ),
                };
                return StorageQuotaExceededResponse {
                    message,
                    current_usage: result.current,
                    limit: result.limit,
                }
                .into_response();
            }

            let response = next.run(request).await;
            // Only count the bytes once the downstream handler succeeded.
            if response.status().is_success() {
                if let Err(e) = state
                    .quota_manager
                    .increment_usage(org_id, UsageMetric::StorageBytes, content_length)
                    .await
                {
                    tracing::warn!("Failed to increment storage for org {}: {}", org_id, e);
                }
            }
            response
        }
        Err(e) => {
            // A failing quota backend must not block the request.
            tracing::error!("Failed to check storage quota for org {}: {}", org_id, e);
            next.run(request).await
        }
    }
}
/// Resolves the organization a request belongs to.
///
/// Sources are tried in this order and the first one yielding a valid UUID wins:
/// 1. the `X-Organization-Id` header,
/// 2. an `org_id` or `organization_id` query parameter (first one that parses),
/// 3. an [`OrganizationContext`] stored in the request extensions.
///
/// Returns `None` when none of them provides an id.
fn extract_organization_id(request: &Request<Body>) -> Option<Uuid> {
    let from_header = || {
        request
            .headers()
            .get("X-Organization-Id")
            .and_then(|raw| raw.to_str().ok())
            .and_then(|text| Uuid::parse_str(text).ok())
    };

    let from_query = || {
        request.uri().query().and_then(|query| {
            query
                .split('&')
                .filter_map(|pair| pair.split_once('='))
                .filter(|(key, _)| matches!(*key, "org_id" | "organization_id"))
                .find_map(|(_, value)| Uuid::parse_str(value).ok())
        })
    };

    let from_context = || {
        request
            .extensions()
            .get::<OrganizationContext>()
            .map(|ctx| ctx.organization_id)
    };

    // Each source is evaluated lazily, only if the previous one found nothing.
    from_header().or_else(from_query).or_else(from_context)
}
/// Maps a request to the usage metric it should be counted against.
///
/// Classification is by substring match on the URI path (and, for bots and
/// users, the `POST` method). Rules are evaluated top to bottom and the first
/// match wins; anything unmatched counts as `UsageMetric::ApiCalls`.
fn determine_metric_from_request(request: &Request<Body>) -> UsageMetric {
    /// True when `path` contains at least one of `fragments`.
    fn mentions_any(path: &str, fragments: &[&str]) -> bool {
        fragments.iter().any(|fragment| path.contains(*fragment))
    }

    let path = request.uri().path();
    let is_post = request.method() == "POST";

    if mentions_any(path, &["/chat", "/message", "/conversation"]) {
        UsageMetric::Messages
    } else if mentions_any(path, &["/upload", "/file", "/storage"]) {
        UsageMetric::StorageBytes
    } else if is_post && path.contains("/bot") {
        UsageMetric::Bots
    } else if is_post && path.contains("/user") {
        UsageMetric::Users
    } else if mentions_any(path, &["/kb", "/document"]) {
        UsageMetric::KbDocuments
    } else if mentions_any(path, &["/app", "/form", "/site"]) {
        UsageMetric::Apps
    } else {
        UsageMetric::ApiCalls
    }
}
/// Per-request organization information stored in the request extensions.
///
/// Inserted by `organization_context_middleware` and read back by
/// `extract_organization_id` when neither the header nor the query string
/// carries an organization id.
#[derive(Clone)]
pub struct OrganizationContext {
/// Organization the request is attributed to.
pub organization_id: Uuid,
/// Optional user id; `organization_context_middleware` in this file leaves it `None`.
pub user_id: Option<Uuid>,
/// Optional plan id; `organization_context_middleware` in this file leaves it `None`.
pub plan_id: Option<String>,
}
/// Attaches an [`OrganizationContext`] to the request when an organization id
/// can be extracted from it, then forwards the request.
///
/// Only `organization_id` is populated; `user_id` and `plan_id` are left as
/// `None`. Requests without a resolvable organization id are forwarded as-is.
pub async fn organization_context_middleware(
    mut request: Request<Body>,
    next: Next,
) -> Response {
    if let Some(organization_id) = extract_organization_id(&request) {
        let context = OrganizationContext {
            organization_id,
            user_id: None,
            plan_id: None,
        };
        request.extensions_mut().insert(context);
    }
    next.run(request).await
}
/// Rejection returned by `quota_middleware` when a quota blocks the request.
struct QuotaExceededResponse {
    /// Human-readable explanation echoed in the JSON body.
    message: String,
}

impl IntoResponse for QuotaExceededResponse {
    /// Renders a `429 Too Many Requests` JSON response flagged with
    /// `X-Quota-Exceeded: true`.
    fn into_response(self) -> Response {
        let Self { message } = self;
        let headers = [
            ("Content-Type", "application/json"),
            ("X-Quota-Exceeded", "true"),
        ];
        let payload = Json(serde_json::json!({
            "error": "quota_exceeded",
            "message": message,
            "code": "QUOTA_EXCEEDED"
        }));
        (StatusCode::TOO_MANY_REQUESTS, headers, payload).into_response()
    }
}
/// Rejection returned by `api_rate_limit_middleware` when the API-call quota
/// blocks the request.
struct RateLimitExceededResponse {
    /// Human-readable explanation echoed in the JSON body.
    message: String,
}

impl IntoResponse for RateLimitExceededResponse {
    /// Renders a `429 Too Many Requests` JSON response with a fixed
    /// 60-second `Retry-After` hint and `X-RateLimit-Exceeded: true`.
    fn into_response(self) -> Response {
        let Self { message } = self;
        let headers = [
            ("Content-Type", "application/json"),
            ("Retry-After", "60"),
            ("X-RateLimit-Exceeded", "true"),
        ];
        let payload = Json(serde_json::json!({
            "error": "rate_limit_exceeded",
            "message": message,
            "code": "RATE_LIMIT_EXCEEDED",
            "retry_after": 60
        }));
        (StatusCode::TOO_MANY_REQUESTS, headers, payload).into_response()
    }
}
/// Rejection returned by `message_quota_middleware` when the message quota
/// blocks the request.
struct MessageQuotaExceededResponse {
    /// Human-readable explanation echoed in the JSON body.
    message: String,
}

impl IntoResponse for MessageQuotaExceededResponse {
    /// Renders a `429 Too Many Requests` JSON response that includes an
    /// `upgrade_url` and is flagged with `X-Message-Quota-Exceeded: true`.
    fn into_response(self) -> Response {
        let Self { message } = self;
        let headers = [
            ("Content-Type", "application/json"),
            ("X-Message-Quota-Exceeded", "true"),
        ];
        let payload = Json(serde_json::json!({
            "error": "message_quota_exceeded",
            "message": message,
            "code": "MESSAGE_QUOTA_EXCEEDED",
            "upgrade_url": "/billing/upgrade"
        }));
        (StatusCode::TOO_MANY_REQUESTS, headers, payload).into_response()
    }
}
/// Rejection returned by `storage_check_middleware` when an upload would
/// exceed the storage quota.
struct StorageQuotaExceededResponse {
    /// Human-readable explanation echoed in the JSON body.
    message: String,
    /// Usage reported by the quota check, emitted as `current_usage_bytes`.
    current_usage: u64,
    /// Limit reported by the quota check, emitted as `limit_bytes` (`null` when absent).
    limit: Option<u64>,
}

impl IntoResponse for StorageQuotaExceededResponse {
    /// Renders a `413 Payload Too Large` JSON response that includes usage,
    /// limit and an `upgrade_url`, flagged with `X-Storage-Quota-Exceeded: true`.
    fn into_response(self) -> Response {
        let Self {
            message,
            current_usage,
            limit,
        } = self;
        let headers = [
            ("Content-Type", "application/json"),
            ("X-Storage-Quota-Exceeded", "true"),
        ];
        let payload = Json(serde_json::json!({
            "error": "storage_quota_exceeded",
            "message": message,
            "code": "STORAGE_QUOTA_EXCEEDED",
            "current_usage_bytes": current_usage,
            "limit_bytes": limit,
            "upgrade_url": "/billing/upgrade"
        }));
        (StatusCode::PAYLOAD_TOO_LARGE, headers, payload).into_response()
    }
}
/// Convenience constructor: builds a [`QuotaMiddlewareState`] around
/// `quota_manager` and wraps it in an `Arc` for use as middleware state.
pub fn create_quota_middleware_state(
    quota_manager: Arc<QuotaManager>,
) -> Arc<QuotaMiddlewareState> {
    let state = QuotaMiddlewareState::new(quota_manager);
    Arc::new(state)
}
/// Switches quota enforcement on or off and logs the new mode at info level.
pub async fn toggle_saas_mode(state: &QuotaMiddlewareState, enabled: bool) {
    state.set_enabled(enabled).await;

    if !enabled {
        tracing::info!("SaaS quota enforcement disabled (local mode)");
        return;
    }
    tracing::info!("SaaS quota enforcement enabled");
}