botserver/src/security/rate_limiter.rs

257 lines
6.8 KiB
Rust

use axum::{
body::Body,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use botlib::{
format_limit_error_response, LimitExceeded, RateLimiter as BotlibRateLimiter, SystemLimits,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter as GovernorRateLimiter,
};
use serde_json::json;
use std::{
net::SocketAddr,
num::NonZeroU32,
sync::Arc,
};
pub type GlobalRateLimiter = GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>;
#[derive(Debug, Clone)]
pub struct HttpRateLimitConfig {
pub requests_per_second: u32,
pub burst_size: u32,
}
impl Default for HttpRateLimitConfig {
fn default() -> Self {
Self {
requests_per_second: 100,
burst_size: 200,
}
}
}
impl HttpRateLimitConfig {
pub fn strict() -> Self {
Self {
requests_per_second: 50,
burst_size: 100,
}
}
pub fn relaxed() -> Self {
Self {
requests_per_second: 500,
burst_size: 1000,
}
}
pub fn api() -> Self {
Self {
requests_per_second: 100,
burst_size: 150,
}
}
}
pub struct CombinedRateLimiter {
http_limiter: Arc<GlobalRateLimiter>,
botlib_limiter: Arc<BotlibRateLimiter>,
}
impl CombinedRateLimiter {
pub fn new(http_config: HttpRateLimitConfig, system_limits: SystemLimits) -> Self {
const DEFAULT_RPS: NonZeroU32 = match NonZeroU32::new(100) {
Some(v) => v,
None => unreachable!(),
};
const DEFAULT_BURST: NonZeroU32 = match NonZeroU32::new(200) {
Some(v) => v,
None => unreachable!(),
};
let quota = Quota::per_second(
NonZeroU32::new(http_config.requests_per_second).unwrap_or(DEFAULT_RPS),
)
.allow_burst(
NonZeroU32::new(http_config.burst_size).unwrap_or(DEFAULT_BURST),
);
Self {
http_limiter: Arc::new(GovernorRateLimiter::direct(quota)),
botlib_limiter: Arc::new(BotlibRateLimiter::new(system_limits)),
}
}
pub fn with_defaults() -> Self {
Self::new(HttpRateLimitConfig::default(), SystemLimits::default())
}
pub fn check_http_limit(&self) -> bool {
self.http_limiter.check().is_ok()
}
pub async fn check_user_limit(&self, user_id: &str) -> Result<(), LimitExceeded> {
self.botlib_limiter.check_rate_limit(user_id).await
}
pub fn botlib_limiter(&self) -> &Arc<BotlibRateLimiter> {
&self.botlib_limiter
}
pub async fn cleanup(&self) {
self.botlib_limiter.cleanup_stale_entries().await;
}
}
impl Clone for CombinedRateLimiter {
fn clone(&self) -> Self {
Self {
http_limiter: Arc::clone(&self.http_limiter),
botlib_limiter: Arc::clone(&self.botlib_limiter),
}
}
}
pub async fn rate_limit_middleware(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
axum::Extension(limiter): axum::Extension<Arc<CombinedRateLimiter>>,
request: Request<Body>,
next: Next,
) -> Response {
if !limiter.check_http_limit() {
return http_rate_limit_response(30);
}
let user_id = extract_user_id(&request).unwrap_or_else(|| addr.ip().to_string());
match limiter.check_user_limit(&user_id).await {
Ok(()) => next.run(request).await,
Err(limit_exceeded) => {
let (status, body) = format_limit_error_response(&limit_exceeded);
(StatusCode::from_u16(status).unwrap_or(StatusCode::TOO_MANY_REQUESTS), body).into_response()
}
}
}
pub async fn simple_rate_limit_middleware(
axum::Extension(limiter): axum::Extension<Arc<CombinedRateLimiter>>,
request: Request<Body>,
next: Next,
) -> Response {
if !limiter.check_http_limit() {
return http_rate_limit_response(30);
}
next.run(request).await
}
fn extract_user_id(request: &Request<Body>) -> Option<String> {
if let Some(user_id) = request.headers().get("x-user-id") {
if let Ok(id) = user_id.to_str() {
return Some(id.to_string());
}
}
if let Some(auth) = request.headers().get("authorization") {
if let Ok(auth_str) = auth.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
if token.len() > 10 {
return Some(format!("token:{}", &token[..10]));
}
}
}
}
None
}
fn http_rate_limit_response(retry_after: u64) -> Response {
let mut response = (
StatusCode::TOO_MANY_REQUESTS,
Json(json!({
"error": "rate_limit_exceeded",
"message": "Too many requests. Please slow down.",
"retry_after_secs": retry_after
})),
)
.into_response();
if let Ok(value) = retry_after.to_string().parse() {
response.headers_mut().insert("Retry-After", value);
}
response
}
pub fn create_rate_limit_layer(
http_config: HttpRateLimitConfig,
system_limits: SystemLimits,
) -> (
axum::Extension<Arc<CombinedRateLimiter>>,
Arc<CombinedRateLimiter>,
) {
let limiter = Arc::new(CombinedRateLimiter::new(http_config, system_limits));
(axum::Extension(Arc::clone(&limiter)), limiter)
}
pub fn create_default_rate_limit_layer() -> (
axum::Extension<Arc<CombinedRateLimiter>>,
Arc<CombinedRateLimiter>,
) {
create_rate_limit_layer(HttpRateLimitConfig::default(), SystemLimits::default())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_http_config_presets() {
let default = HttpRateLimitConfig::default();
assert_eq!(default.requests_per_second, 100);
let strict = HttpRateLimitConfig::strict();
assert_eq!(strict.requests_per_second, 50);
let relaxed = HttpRateLimitConfig::relaxed();
assert_eq!(relaxed.requests_per_second, 500);
let api = HttpRateLimitConfig::api();
assert_eq!(api.requests_per_second, 100);
}
#[test]
fn test_combined_limiter_creation() {
let limiter = CombinedRateLimiter::with_defaults();
assert!(limiter.check_http_limit());
}
#[test]
fn test_combined_limiter_clone() {
let limiter = CombinedRateLimiter::with_defaults();
let cloned = limiter.clone();
assert!(cloned.check_http_limit());
}
#[tokio::test]
async fn test_user_rate_limit() {
let limiter = CombinedRateLimiter::with_defaults();
let result = limiter.check_user_limit("test-user").await;
assert!(result.is_ok());
}
#[test]
fn test_extract_user_id_none() {
let request = Request::builder()
.body(Body::empty())
.expect("valid syntax registration");
assert!(extract_user_id(&request).is_none());
}
}