Some checks are pending
BotServer CI/CD / build (push) Waiting to run
- Remove all std::env::var calls except VAULT_* and PORT - get_from_env returns hardcoded defaults only (no env var reading) - Auth config, rate limits, email, analytics, calendar all use Vault - WORK_PATH replaced with get_work_path() helper reading from Vault - .env on production cleaned to only VAULT_ADDR, VAULT_TOKEN, VAULT_CACERT, PORT - All service IPs/credentials stored in Vault secret/gbo/*
238 lines
6.3 KiB
Rust
238 lines
6.3 KiB
Rust
use axum::{
|
|
extract::{ConnectInfo, Request, State},
|
|
http::StatusCode,
|
|
middleware::Next,
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use governor::{
|
|
clock::DefaultClock,
|
|
middleware::NoOpMiddleware,
|
|
state::{InMemoryState, NotKeyed},
|
|
Quota, RateLimiter,
|
|
};
|
|
use std::{collections::HashMap, net::SocketAddr, num::NonZeroU32, sync::Arc};
|
|
use tokio::sync::RwLock;
|
|
|
|
/// Un-keyed, in-memory `governor` rate limiter using the default clock and no middleware.
type Limiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
|
|
|
|
pub struct KeyedRateLimiter {
|
|
limiters: RwLock<HashMap<String, Arc<Limiter>>>,
|
|
quota: Quota,
|
|
cleanup_threshold: usize,
|
|
}
|
|
|
|
impl KeyedRateLimiter {
|
|
pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
|
|
let quota =
|
|
Quota::per_second(NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::MIN))
|
|
.allow_burst(NonZeroU32::new(burst_size).unwrap_or(NonZeroU32::MIN));
|
|
|
|
Self {
|
|
limiters: RwLock::new(HashMap::new()),
|
|
quota,
|
|
cleanup_threshold: 10000,
|
|
}
|
|
}
|
|
|
|
pub async fn check(&self, key: &str) -> bool {
|
|
let limiter = {
|
|
let limiters = self.limiters.read().await;
|
|
limiters.get(key).cloned()
|
|
};
|
|
|
|
let limiter = match limiter {
|
|
Some(l) => l,
|
|
None => {
|
|
let mut limiters = self.limiters.write().await;
|
|
|
|
if limiters.len() > self.cleanup_threshold {
|
|
limiters.clear();
|
|
}
|
|
|
|
let new_limiter = Arc::new(RateLimiter::direct(self.quota));
|
|
limiters.insert(key.to_string(), Arc::clone(&new_limiter));
|
|
new_limiter
|
|
}
|
|
};
|
|
|
|
limiter.check().is_ok()
|
|
}
|
|
|
|
pub async fn remaining(&self, key: &str) -> Option<u32> {
|
|
let limiters = self.limiters.read().await;
|
|
limiters.get(key).map(|l| l.check().map(|_| 1).unwrap_or(0))
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for KeyedRateLimiter {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("KeyedRateLimiter")
|
|
.field("cleanup_threshold", &self.cleanup_threshold)
|
|
.field(
|
|
"limiters",
|
|
&format!("<{} entries>", self.limiters.blocking_read().len()),
|
|
)
|
|
.finish_non_exhaustive()
|
|
}
|
|
}
|
|
|
|
/// Rate-limit settings for the three request tiers (general API, auth, LLM).
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests per second for the general API tier.
    pub api_rps: u32,

    /// Burst size for the general API tier.
    pub api_burst: u32,

    /// Requests per second for the authentication tier.
    pub auth_rps: u32,

    /// Burst size for the authentication tier.
    pub auth_burst: u32,

    /// Requests per second for the LLM tier.
    pub llm_rps: u32,

    /// Burst size for the LLM tier.
    pub llm_burst: u32,

    /// Master switch for rate limiting.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Returns the built-in limits: API 100/200, auth 10/20, LLM 5/10 (rps/burst),
    /// with rate limiting enabled.
    fn default() -> Self {
        let (api_rps, api_burst) = (100, 200);
        let (auth_rps, auth_burst) = (10, 20);
        let (llm_rps, llm_burst) = (5, 10);

        Self {
            api_rps,
            api_burst,
            auth_rps,
            auth_burst,
            llm_rps,
            llm_burst,
            enabled: true,
        }
    }
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct RateLimitState {
|
|
pub config: RateLimitConfig,
|
|
pub api_limiter: KeyedRateLimiter,
|
|
pub auth_limiter: KeyedRateLimiter,
|
|
pub llm_limiter: KeyedRateLimiter,
|
|
}
|
|
|
|
impl RateLimitState {
|
|
pub fn new(config: RateLimitConfig) -> Self {
|
|
Self {
|
|
api_limiter: KeyedRateLimiter::new(config.api_rps, config.api_burst),
|
|
auth_limiter: KeyedRateLimiter::new(config.auth_rps, config.auth_burst),
|
|
llm_limiter: KeyedRateLimiter::new(config.llm_rps, config.llm_burst),
|
|
config,
|
|
}
|
|
}
|
|
|
|
pub fn from_env() -> Self {
|
|
let config = RateLimitConfig {
|
|
api_rps: 100,
|
|
api_burst: 200,
|
|
auth_rps: 10,
|
|
auth_burst: 20,
|
|
llm_rps: 5,
|
|
llm_burst: 10,
|
|
enabled: true,
|
|
};
|
|
Self::new(config)
|
|
}
|
|
}
|
|
|
|
fn get_client_ip(req: &Request) -> String {
|
|
if let Some(forwarded) = req.headers().get("x-forwarded-for") {
|
|
if let Ok(value) = forwarded.to_str() {
|
|
if let Some(ip) = value.split(',').next() {
|
|
return ip.trim().to_string();
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(real_ip) = req.headers().get("x-real-ip") {
|
|
if let Ok(value) = real_ip.to_str() {
|
|
return value.to_string();
|
|
}
|
|
}
|
|
|
|
req.extensions()
|
|
.get::<ConnectInfo<SocketAddr>>()
|
|
.map(|ci| ci.0.ip().to_string())
|
|
.unwrap_or_else(|| "unknown".to_string())
|
|
}
|
|
|
|
/// Picks the rate-limit tier for a request path by substring matching.
///
/// Auth markers are tested first, so a path containing both an auth marker and an
/// LLM marker is classified as [`LimiterType::Auth`]. Paths matching neither set
/// fall back to [`LimiterType::Api`].
fn get_limiter_type(path: &str) -> LimiterType {
    const AUTH_MARKERS: &[&str] = &["/auth", "/login", "/token"];
    const LLM_MARKERS: &[&str] = &["/llm", "/chat", "/generate"];

    /// Returns `true` if `path` contains at least one of `markers`.
    fn contains_any(path: &str, markers: &[&str]) -> bool {
        markers.iter().any(|marker| path.contains(*marker))
    }

    if contains_any(path, AUTH_MARKERS) {
        return LimiterType::Auth;
    }
    if contains_any(path, LLM_MARKERS) {
        return LimiterType::Llm;
    }
    LimiterType::Api
}

/// Rate-limit tier a request is assigned to.
#[derive(Debug, Clone, Copy)]
enum LimiterType {
    /// General API traffic (the fallback tier).
    Api,
    /// Authentication-related paths.
    Auth,
    /// LLM / chat / generation paths.
    Llm,
}
|
|
|
|
pub async fn rate_limit_middleware(
|
|
State(state): State<Arc<RateLimitState>>,
|
|
req: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
if !state.config.enabled {
|
|
return next.run(req).await;
|
|
}
|
|
|
|
let client_ip = get_client_ip(&req);
|
|
let path = req.uri().path();
|
|
let limiter_type = get_limiter_type(path);
|
|
|
|
let allowed = match limiter_type {
|
|
LimiterType::Api => state.api_limiter.check(&client_ip).await,
|
|
LimiterType::Auth => state.auth_limiter.check(&client_ip).await,
|
|
LimiterType::Llm => state.llm_limiter.check(&client_ip).await,
|
|
};
|
|
|
|
if allowed {
|
|
next.run(req).await
|
|
} else {
|
|
rate_limit_response(limiter_type)
|
|
}
|
|
}
|
|
|
|
fn rate_limit_response(limiter_type: LimiterType) -> Response {
|
|
let (retry_after, message) = match limiter_type {
|
|
LimiterType::Api => (1, "API rate limit exceeded"),
|
|
LimiterType::Auth => (
|
|
60,
|
|
"Authentication rate limit exceeded. Please wait before trying again.",
|
|
),
|
|
LimiterType::Llm => (
|
|
10,
|
|
"LLM rate limit exceeded. Please wait before sending another request.",
|
|
),
|
|
};
|
|
|
|
let body = serde_json::json!({
|
|
"error": "rate_limit_exceeded",
|
|
"message": message,
|
|
"retry_after": retry_after
|
|
});
|
|
|
|
(
|
|
StatusCode::TOO_MANY_REQUESTS,
|
|
[
|
|
("Retry-After", retry_after.to_string()),
|
|
("Content-Type", "application/json".to_string()),
|
|
],
|
|
body.to_string(),
|
|
)
|
|
.into_response()
|
|
}
|
|
|
|
pub fn create_rate_limit_state(config: RateLimitConfig) -> Arc<RateLimitState> {
|
|
Arc::new(RateLimitState::new(config))
|
|
}
|