// generalbots/src/security/auth_provider.rs
//
// 603 lines
// 17 KiB
// Rust

use crate::security::auth::{AuthConfig, AuthError, AuthenticatedUser, Role};
use crate::security::jwt::{Claims, JwtManager};
use crate::security::zitadel_auth::{ZitadelAuthConfig, ZitadelAuthProvider};
use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// A pluggable authentication backend consulted in turn by [`AuthProviderRegistry`].
///
/// The `Send + Sync` bound is needed because the registry stores providers as
/// `Arc<dyn AuthProvider>` behind an async `RwLock`.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Identifier used in log lines and in `AuthProviderRegistry::list_providers`.
    fn name(&self) -> &str;

    /// Ordering key. The registry sorts providers ascending by this value, so
    /// smaller numbers are consulted first.
    fn priority(&self) -> i32;

    /// Providers reporting `false` are skipped entirely by the registry.
    fn is_enabled(&self) -> bool;

    /// Validates a bearer token and resolves it to an [`AuthenticatedUser`].
    async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError>;

    /// Validates an API key and resolves it to an [`AuthenticatedUser`].
    async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError>;

    /// Cheap pre-filter for bearer tokens: the registry only calls
    /// [`AuthProvider::authenticate`] when this returns `true`.
    fn supports_token_type(&self, token: &str) -> bool;
}
/// Authenticates bearer tokens that were issued and signed by this service's own
/// [`JwtManager`].
pub struct LocalJwtAuthProvider {
    jwt_manager: Arc<JwtManager>,
    enabled: bool,
}

impl LocalJwtAuthProvider {
    /// Wraps `jwt_manager`; the provider starts out enabled.
    pub fn new(jwt_manager: Arc<JwtManager>) -> Self {
        Self {
            enabled: true,
            jwt_manager,
        }
    }

    /// Returns the provider with its enabled flag replaced by `enabled`.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Maps validated JWT claims onto an [`AuthenticatedUser`].
    ///
    /// - A missing `username` claim becomes `user-<id>`.
    /// - A missing `roles` claim defaults to `[Role::User]`. Role strings that fail to
    ///   parse are dropped, so a present-but-unparseable list yields no roles.
    /// - Optional `email` and `session_id` claims are copied when present.
    /// - An `organization_id` that is not a valid UUID is silently ignored.
    ///
    /// # Errors
    /// `AuthError::InvalidToken` when `Claims::user_id` fails.
    fn claims_to_user(&self, claims: &Claims) -> Result<AuthenticatedUser, AuthError> {
        let user_id = match claims.user_id() {
            Ok(id) => id,
            Err(_) => return Err(AuthError::InvalidToken),
        };

        let username = match &claims.username {
            Some(name) => name.clone(),
            None => format!("user-{}", user_id),
        };

        let roles: Vec<Role> = match &claims.roles {
            Some(raw_roles) => raw_roles
                .iter()
                .filter_map(|raw| Role::from_str(raw).ok())
                .collect(),
            None => vec![Role::User],
        };

        let mut user = AuthenticatedUser::new(user_id, username).with_roles(roles);

        if let Some(email) = &claims.email {
            user = user.with_email(email);
        }
        if let Some(session) = &claims.session_id {
            user = user.with_session(session);
        }

        let org_uuid = claims
            .organization_id
            .as_ref()
            .and_then(|org| Uuid::parse_str(org).ok());
        if let Some(org_uuid) = org_uuid {
            user = user.with_organization(org_uuid);
        }

        Ok(user)
    }
}
#[async_trait]
impl AuthProvider for LocalJwtAuthProvider {
    fn name(&self) -> &str {
        "local-jwt"
    }

    /// Registry ordering key; the registry consults lower values first.
    fn priority(&self) -> i32 {
        100
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Validates `token` as an access token with the local [`JwtManager`] and maps
    /// the resulting claims to a user.
    ///
    /// # Errors
    /// `AuthError::InvalidToken` when validation fails (the underlying error is only
    /// logged at debug level) or when the claims cannot be mapped to a user.
    async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
        let claims = self.jwt_manager.validate_access_token(token).map_err(|e| {
            debug!("JWT validation failed: {e}");
            AuthError::InvalidToken
        })?;
        self.claims_to_user(&claims)
    }

    /// This provider never accepts API keys.
    async fn authenticate_api_key(&self, _api_key: &str) -> Result<AuthenticatedUser, AuthError> {
        Err(AuthError::InvalidApiKey)
    }

    /// Shape check only: a compact JWT has exactly three `.`-separated segments.
    /// Real validation happens in [`AuthProvider::authenticate`].
    fn supports_token_type(&self, token: &str) -> bool {
        // Count segments lazily instead of collecting them into a throwaway Vec.
        token.split('.').count() == 3
    }
}
/// Bridges a [`ZitadelAuthProvider`] into the [`AuthProvider`] trait so it can be
/// registered alongside the other providers.
pub struct ZitadelAuthProviderAdapter {
    provider: Arc<ZitadelAuthProvider>,
    enabled: bool,
}

impl ZitadelAuthProviderAdapter {
    /// Wraps `provider`; the adapter starts out enabled.
    pub fn new(provider: Arc<ZitadelAuthProvider>) -> Self {
        let enabled = true;
        Self { provider, enabled }
    }

    /// Returns the adapter with its enabled flag replaced by `enabled`.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
#[async_trait]
impl AuthProvider for ZitadelAuthProviderAdapter {
    fn name(&self) -> &str {
        "zitadel"
    }

    /// Registry ordering key; the registry consults lower values first.
    fn priority(&self) -> i32 {
        50
    }

    fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Delegates bearer-token validation to the wrapped [`ZitadelAuthProvider`].
    async fn authenticate(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
        self.provider.authenticate_token(token).await
    }

    /// Delegates API-key validation to the wrapped [`ZitadelAuthProvider`].
    async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError> {
        self.provider.authenticate_api_key(api_key).await
    }

    /// Shape check only: a compact JWT has exactly three `.`-separated segments.
    fn supports_token_type(&self, token: &str) -> bool {
        // Count segments lazily instead of collecting them into a throwaway Vec.
        token.split('.').count() == 3
    }
}
/// Authenticates callers by looking up a digest of the presented API key in an
/// in-memory table of registered keys.
pub struct ApiKeyAuthProvider {
    /// Registered keys, indexed by the output of `ApiKeyAuthProvider::hash_key`.
    valid_keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
    enabled: bool,
}

/// Identity attached to a registered API key.
#[derive(Clone)]
pub struct ApiKeyInfo {
    pub user_id: Uuid,
    pub username: String,
    pub roles: Vec<Role>,
    pub organization_id: Option<Uuid>,
    pub scopes: Vec<String>,
}

impl ApiKeyAuthProvider {
    /// Creates an enabled provider with no registered keys.
    pub fn new() -> Self {
        Self {
            valid_keys: Arc::new(RwLock::new(HashMap::new())),
            enabled: true,
        }
    }

    /// Returns the provider with its enabled flag replaced by `enabled`.
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Registers `info` under a precomputed key digest, replacing any existing entry.
    ///
    /// `key_hash` must come from the same scheme as `hash_key`, or lookups will never
    /// match. Because `hash_key` is private, code outside this module should use
    /// [`ApiKeyAuthProvider::register_api_key`] instead.
    pub async fn register_key(&self, key_hash: String, info: ApiKeyInfo) {
        let mut keys = self.valid_keys.write().await;
        keys.insert(key_hash, info);
    }

    /// Hashes `api_key` with this provider's scheme and registers `info` under it.
    ///
    /// Keys shorter than 16 bytes are refused at authentication time, so registering
    /// one has no useful effect.
    pub async fn register_api_key(&self, api_key: &str, info: ApiKeyInfo) {
        self.register_key(Self::hash_key(api_key), info).await;
    }

    /// Removes the entry stored under `key_hash`, if any.
    pub async fn revoke_key(&self, key_hash: &str) {
        let mut keys = self.valid_keys.write().await;
        keys.remove(key_hash);
    }

    /// Hashes `api_key` with this provider's scheme and removes its entry, if any.
    pub async fn revoke_api_key(&self, api_key: &str) {
        self.revoke_key(&Self::hash_key(api_key)).await;
    }

    /// Digests a raw API key into the lookup key used by `valid_keys`.
    ///
    /// NOTE(review): `DefaultHasher::new()` is unkeyed (every instance yields the same
    /// output), produces only 64 bits, and is not a cryptographic hash. std also
    /// documents that its algorithm may change between Rust releases. That is
    /// acceptable for an in-memory table rebuilt each run, but do not persist these
    /// digests. Consider a keyed cryptographic hash if keys are ever stored.
    fn hash_key(key: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        format!("{:x}", hasher.finish())
    }
}

impl Default for ApiKeyAuthProvider {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl AuthProvider for ApiKeyAuthProvider {
fn name(&self) -> &str {
"api-key"
}
fn priority(&self) -> i32 {
200
}
fn is_enabled(&self) -> bool {
self.enabled
}
async fn authenticate(&self, _token: &str) -> Result<AuthenticatedUser, AuthError> {
Err(AuthError::InvalidToken)
}
async fn authenticate_api_key(&self, api_key: &str) -> Result<AuthenticatedUser, AuthError> {
if api_key.len() < 16 {
return Err(AuthError::InvalidApiKey);
}
let key_hash = Self::hash_key(api_key);
let keys = self.valid_keys.read().await;
if let Some(info) = keys.get(&key_hash) {
let mut user = AuthenticatedUser::new(info.user_id, info.username.clone())
.with_roles(info.roles.clone());
if let Some(org_id) = info.organization_id {
user = user.with_organization(org_id);
}
for scope in &info.scopes {
user = user.with_metadata("scope", scope);
}
return Ok(user);
}
let user = AuthenticatedUser::service("api-client")
.with_metadata("api_key_prefix", &api_key[..8.min(api_key.len())]);
Ok(user)
}
fn supports_token_type(&self, _token: &str) -> bool {
false
}
}
/// Ordered set of [`AuthProvider`]s that are tried in turn until one accepts the
/// presented credential.
pub struct AuthProviderRegistry {
    /// Kept sorted ascending by [`AuthProvider::priority`]; lower values are tried first.
    providers: Arc<RwLock<Vec<Arc<dyn AuthProvider>>>>,
    /// When `true`, a credential that no provider accepts yields
    /// `AuthenticatedUser::anonymous()` instead of an error.
    fallback_enabled: bool,
}

impl AuthProviderRegistry {
    /// Creates an empty registry with the anonymous fallback disabled.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(RwLock::new(Vec::new())),
            fallback_enabled: false,
        }
    }

    /// Returns the registry with the anonymous fallback set to `enabled`.
    pub fn with_fallback(mut self, enabled: bool) -> Self {
        self.fallback_enabled = enabled;
        self
    }

    /// Adds `provider` and re-sorts the list by priority.
    pub async fn register(&self, provider: Arc<dyn AuthProvider>) {
        // Capture the identity before `provider` moves into the list. After sorting,
        // `last()` is the provider with the largest priority value, which is not
        // necessarily the one just registered.
        let name = provider.name().to_owned();
        let priority = provider.priority();
        {
            let mut providers = self.providers.write().await;
            providers.push(provider);
            // `sort_by_key` is stable: equal priorities keep registration order.
            providers.sort_by_key(|p| p.priority());
        }
        info!("Registered auth provider: {name} (priority: {priority})");
    }

    /// Copies the current provider list so the lock is released before any provider
    /// is awaited. Tokio's `RwLock` is write-preferring, so holding a read guard
    /// across slow (e.g. network-backed) providers would make a pending `register`
    /// block every new authentication.
    async fn snapshot(&self) -> Vec<Arc<dyn AuthProvider>> {
        self.providers.read().await.iter().map(Arc::clone).collect()
    }

    /// Tries each enabled provider that claims `token`, in priority order.
    ///
    /// # Errors
    /// `AuthError::InvalidToken` when no provider accepts the token and the fallback
    /// is disabled. Individual provider failures are only logged at debug level.
    pub async fn authenticate_token(&self, token: &str) -> Result<AuthenticatedUser, AuthError> {
        for provider in self.snapshot().await {
            if !provider.is_enabled() || !provider.supports_token_type(token) {
                continue;
            }
            match provider.authenticate(token).await {
                Ok(user) => {
                    debug!("Token authenticated via provider: {}", provider.name());
                    return Ok(user);
                }
                Err(e) => {
                    debug!("Provider {} failed: {:?}", provider.name(), e);
                }
            }
        }
        if self.fallback_enabled {
            warn!("All providers failed, using anonymous fallback");
            return Ok(AuthenticatedUser::anonymous());
        }
        Err(AuthError::InvalidToken)
    }

    /// Tries each enabled provider's API-key check, in priority order.
    ///
    /// # Errors
    /// `AuthError::InvalidApiKey` when no provider accepts the key and the fallback is
    /// disabled.
    pub async fn authenticate_api_key(
        &self,
        api_key: &str,
    ) -> Result<AuthenticatedUser, AuthError> {
        for provider in self.snapshot().await {
            if !provider.is_enabled() {
                continue;
            }
            match provider.authenticate_api_key(api_key).await {
                Ok(user) => {
                    debug!("API key authenticated via provider: {}", provider.name());
                    return Ok(user);
                }
                // "Not mine" is the expected answer from most providers; skip quietly.
                Err(AuthError::InvalidApiKey) => {}
                Err(e) => {
                    debug!("Provider {} API key auth failed: {:?}", provider.name(), e);
                }
            }
        }
        if self.fallback_enabled {
            warn!("All providers failed for API key, using anonymous fallback");
            return Ok(AuthenticatedUser::anonymous());
        }
        Err(AuthError::InvalidApiKey)
    }

    /// Number of registered providers, enabled or not.
    pub async fn provider_count(&self) -> usize {
        self.providers.read().await.len()
    }

    /// Human-readable `name (priority: N, enabled: B)` entries, in priority order.
    pub async fn list_providers(&self) -> Vec<String> {
        self.providers
            .read()
            .await
            .iter()
            .map(|p| {
                format!(
                    "{} (priority: {}, enabled: {})",
                    p.name(),
                    p.priority(),
                    p.is_enabled()
                )
            })
            .collect()
    }
}

impl Default for AuthProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Fluent builder that assembles an [`AuthProviderRegistry`] from optional parts.
///
/// The default value has no providers and the anonymous fallback disabled.
#[derive(Default)]
pub struct AuthProviderBuilder {
    jwt_manager: Option<Arc<JwtManager>>,
    zitadel_provider: Option<Arc<ZitadelAuthProvider>>,
    // NOTE(review): only the presence of this config gates Zitadel registration in
    // `build`; its contents are not read there.
    zitadel_config: Option<ZitadelAuthConfig>,
    // NOTE(review): stored but not consumed by `build`; confirm whether still needed.
    auth_config: Option<Arc<AuthConfig>>,
    api_key_provider: Option<Arc<ApiKeyAuthProvider>>,
    fallback_enabled: bool,
}

impl AuthProviderBuilder {
    /// Creates an empty builder (equivalent to `AuthProviderBuilder::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a local JWT provider backed by `manager`.
    pub fn with_jwt_manager(self, manager: Arc<JwtManager>) -> Self {
        Self {
            jwt_manager: Some(manager),
            ..self
        }
    }

    /// Adds a Zitadel provider; both halves are required for it to be registered.
    pub fn with_zitadel(
        self,
        provider: Arc<ZitadelAuthProvider>,
        config: ZitadelAuthConfig,
    ) -> Self {
        Self {
            zitadel_provider: Some(provider),
            zitadel_config: Some(config),
            ..self
        }
    }

    /// Records a shared [`AuthConfig`].
    pub fn with_auth_config(self, config: Arc<AuthConfig>) -> Self {
        Self {
            auth_config: Some(config),
            ..self
        }
    }

    /// Adds an API-key provider.
    pub fn with_api_key_provider(self, provider: Arc<ApiKeyAuthProvider>) -> Self {
        Self {
            api_key_provider: Some(provider),
            ..self
        }
    }

    /// Enables or disables the anonymous fallback on the built registry.
    pub fn with_fallback(self, enabled: bool) -> Self {
        Self {
            fallback_enabled: enabled,
            ..self
        }
    }

    /// Registers every configured provider (JWT, then Zitadel, then API key) into a
    /// fresh registry. The registry re-sorts by priority on each registration.
    pub async fn build(self) -> AuthProviderRegistry {
        let Self {
            jwt_manager,
            zitadel_provider,
            zitadel_config,
            auth_config: _,
            api_key_provider,
            fallback_enabled,
        } = self;

        let registry = AuthProviderRegistry::new().with_fallback(fallback_enabled);

        if let Some(manager) = jwt_manager {
            registry
                .register(Arc::new(LocalJwtAuthProvider::new(manager)))
                .await;
        }

        if let (Some(zitadel), Some(_)) = (zitadel_provider, zitadel_config) {
            registry
                .register(Arc::new(ZitadelAuthProviderAdapter::new(zitadel)))
                .await;
        }

        if let Some(api_keys) = api_key_provider {
            registry.register(api_keys).await;
        }

        registry
    }
}
/// Builds the standard registry: a local JWT provider keyed by `jwt_secret`, an
/// (initially empty) API-key provider, and a Zitadel provider when `zitadel_config`
/// is present and reports itself configured. The anonymous fallback is disabled.
///
/// # Errors
/// Propagates the error from `JwtManager::new`. A Zitadel provider that fails to
/// construct is logged and skipped rather than treated as fatal.
pub async fn create_default_registry(
    jwt_secret: &str,
    zitadel_config: Option<ZitadelAuthConfig>,
) -> Result<AuthProviderRegistry> {
    let jwt_manager = JwtManager::new(
        crate::security::jwt::JwtConfig::default(),
        crate::security::jwt::JwtKey::from_secret(jwt_secret),
    )?;

    let base = AuthProviderBuilder::new()
        .with_jwt_manager(Arc::new(jwt_manager))
        .with_api_key_provider(Arc::new(ApiKeyAuthProvider::new()))
        .with_fallback(false);

    let builder = match zitadel_config.filter(|cfg| cfg.is_configured()) {
        None => base,
        Some(cfg) => match ZitadelAuthProvider::new(cfg.clone()) {
            Ok(zitadel) => {
                let auth_config = Arc::new(AuthConfig::default());
                let configured = base
                    .with_zitadel(Arc::new(zitadel), cfg)
                    .with_auth_config(auth_config);
                info!("Zitadel authentication provider configured");
                configured
            }
            Err(e) => {
                error!("Failed to create Zitadel provider: {e}");
                base
            }
        },
    };

    Ok(builder.build().await)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `JwtManager` keyed with a fixed, test-only secret.
    fn create_test_jwt_manager() -> Arc<JwtManager> {
        let manager = JwtManager::new(
            crate::security::jwt::JwtConfig::default(),
            crate::security::jwt::JwtKey::from_secret("test-secret-key-for-testing-only"),
        )
        .expect("Failed to create JwtManager");
        Arc::new(manager)
    }

    /// A local-JWT provider backed by a fresh test manager.
    fn local_jwt_provider() -> Arc<LocalJwtAuthProvider> {
        Arc::new(LocalJwtAuthProvider::new(create_test_jwt_manager()))
    }

    #[tokio::test]
    async fn test_registry_creation() {
        assert_eq!(AuthProviderRegistry::new().provider_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_provider() {
        let registry = AuthProviderRegistry::new();
        registry.register(local_jwt_provider()).await;
        assert_eq!(registry.provider_count().await, 1);
    }

    #[tokio::test]
    async fn test_jwt_provider_validates_token() {
        let manager = create_test_jwt_manager();
        let provider = LocalJwtAuthProvider::new(Arc::clone(&manager));
        let tokens = manager
            .generate_token_pair(Uuid::new_v4())
            .expect("Failed to generate token");
        assert!(provider.authenticate(&tokens.access_token).await.is_ok());
    }

    #[tokio::test]
    async fn test_jwt_provider_rejects_invalid_token() {
        let provider = LocalJwtAuthProvider::new(create_test_jwt_manager());
        assert!(provider.authenticate("invalid.token.here").await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = ApiKeyAuthProvider::new();
        let info = ApiKeyInfo {
            user_id: Uuid::new_v4(),
            username: "test-user".to_string(),
            roles: vec![Role::User],
            organization_id: None,
            scopes: vec!["read".to_string()],
        };
        let key = "test-api-key-12345678";
        provider
            .register_key(ApiKeyAuthProvider::hash_key(key), info)
            .await;
        assert!(provider.authenticate_api_key(key).await.is_ok());
    }

    #[tokio::test]
    async fn test_registry_with_fallback() {
        let registry = AuthProviderRegistry::new().with_fallback(true);
        let user = registry
            .authenticate_token("invalid-token")
            .await
            .expect("Expected anonymous user");
        assert!(!user.is_authenticated());
    }

    #[tokio::test]
    async fn test_registry_without_fallback() {
        let registry = AuthProviderRegistry::new().with_fallback(false);
        assert!(registry.authenticate_token("invalid-token").await.is_err());
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let registry = AuthProviderBuilder::new()
            .with_jwt_manager(create_test_jwt_manager())
            .with_api_key_provider(Arc::new(ApiKeyAuthProvider::new()))
            .with_fallback(false)
            .build()
            .await;
        assert_eq!(registry.provider_count().await, 2);
    }

    #[tokio::test]
    async fn test_list_providers() {
        let registry = AuthProviderRegistry::new();
        registry.register(local_jwt_provider()).await;
        let listed = registry.list_providers().await;
        assert_eq!(listed.len(), 1);
        assert!(listed[0].contains("local-jwt"));
    }
}