diff --git a/src/basic/keywords/use_kb.rs b/src/basic/keywords/use_kb.rs index a9918578..01e74e71 100644 --- a/src/basic/keywords/use_kb.rs +++ b/src/basic/keywords/use_kb.rs @@ -57,9 +57,10 @@ pub fn register_use_kb_keyword( let conn = state_clone_for_syntax.conn.clone(); let kb_name_clone = kb_name.clone(); - let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) - .join(); + let result = std::thread::spawn(move || { + add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone) + }) + .join(); match result { Ok(Ok(_)) => { @@ -103,9 +104,10 @@ pub fn register_use_kb_keyword( let conn = state_clone_lower.conn.clone(); let kb_name_clone = kb_name.to_string(); - let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) - .join(); + let result = std::thread::spawn(move || { + add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone) + }) + .join(); match result { Ok(Ok(_)) => { @@ -135,9 +137,10 @@ pub fn register_use_kb_keyword( let conn = state_clone2.conn.clone(); let kb_name_clone = kb_name.to_string(); - let result = - std::thread::spawn(move || add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone)) - .join(); + let result = std::thread::spawn(move || { + add_kb_to_session(conn, session_id, bot_id, user_id, &kb_name_clone) + }) + .join(); match result { Ok(Ok(_)) => { @@ -185,7 +188,11 @@ fn add_kb_to_session( .map_err(|e| format!("Failed to check KB existence: {}", e))?; let (kb_folder_path, qdrant_collection) = if let Some(kb_result) = kb_exists { - // CHECK ACCESS + #[derive(QueryableByName)] + struct AccessCheck { + #[diesel(sql_type = diesel::sql_types::Bool)] + exists: bool, + } let has_access: bool = diesel::sql_query( "SELECT EXISTS ( SELECT 1 FROM kb_collections kc @@ -198,12 +205,13 @@ fn add_kb_to_session( WHERE kga.kb_id = kc.id AND rug.user_id = $2 ) ) - )" + ) AS exists", ) .bind::(kb_result.id) .bind::(user_id) - .get_result::(&mut conn) - .map_err(|e| format!("Failed to check KB access: {}", e))?; + .get_result::(&mut conn) + .map_err(|e| format!("Failed to check KB access: {}", e))? + .exists; if !has_access { return Err(format!("Access denied for KB '{}'", kb_name)); diff --git a/src/drive/mod.rs b/src/drive/mod.rs index 141962b7..3a97061f 100644 --- a/src/drive/mod.rs +++ b/src/drive/mod.rs @@ -12,6 +12,9 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use tokio::task::JoinError; +use diesel::prelude::*; +use diesel::sql_types::*; pub mod drive_types; pub mod drive_handlers; @@ -357,17 +360,28 @@ pub async fn list_files( let mut items = Vec::new(); let prefix = params.path.as_deref().unwrap_or(""); - // Fetch KBs from database to mark them in the list let kbs: Vec<(String, bool)> = { let conn = state.conn.clone(); - tokio::task::spawn_blocking(move || { + let kbs_result = tokio::task::spawn_blocking(move || -> Result, String> { + #[derive(QueryableByName)] + struct KbRow { + #[diesel(sql_type = diesel::sql_types::Text)] + name: String, + #[diesel(sql_type = diesel::sql_types::Bool)] + is_public: bool, + } let mut db_conn = conn.get().map_err(|e| e.to_string())?; - use crate::core::shared::models::schema::kb_collections; - kb_collections::table - .select((kb_collections::name, kb_collections::is_public)) - .load::<(String, bool)>(&mut db_conn) - .map_err(|e| e.to_string()) - }).await.unwrap_or(Ok(vec![])).unwrap_or_default() + let rows: Vec = diesel::sql_query( + "SELECT name, COALESCE(is_public, false) as is_public FROM kb_collections" + ) + .load(&mut db_conn) + .map_err(|e| e.to_string())?; + Ok(rows.into_iter().map(|r| (r.name, r.is_public)).collect()) + }).await; + match kbs_result { + Ok(Ok(kbs)) => kbs, + _ => vec![], + } }; let paginator = s3_client @@ -401,6 +415,8 @@ pub async fn list_files( size: None, modified: None, icon: get_file_icon(&dir), + is_kb: false, + is_public: true, }); } } @@ -946,6 +962,8 @@ pub async fn search_files( size: obj.size(), modified: obj.last_modified().map(|t| t.to_string()), icon: get_file_icon(key), + is_kb: false, + is_public: true, }); } } else { @@ -956,6 +974,8 @@ pub async fn search_files( size: obj.size(), modified: obj.last_modified().map(|t| t.to_string()), icon: get_file_icon(key), + is_kb: false, + is_public: true, }); } } @@ -1016,6 +1036,8 @@ pub async fn recent_files( size: obj.size(), modified: obj.last_modified().map(|t| t.to_string()), icon: get_file_icon(key), + is_kb: false, + is_public: true, }); } } diff --git a/src/settings/rbac_kb.rs b/src/settings/rbac_kb.rs index f4fddfba..abed786e 100644 --- a/src/settings/rbac_kb.rs +++ b/src/settings/rbac_kb.rs @@ -18,6 +18,9 @@ use axum::{ response::IntoResponse, Json, }; +use diesel::prelude::*; +use diesel::sql_types::*; +use tokio::task::JoinError; use chrono::Utc; use diesel::prelude::*; use log::info; @@ -59,16 +62,55 @@ pub async fn get_kb_groups( Path(kb_id): Path, ) -> impl IntoResponse { let conn = state.conn.clone(); - let result = tokio::task::spawn_blocking(move || { + let result: Result, String>, JoinError> = tokio::task::spawn_blocking(move || { let mut db_conn = conn.get().map_err(|e| format!("DB error: {e}"))?; - use crate::core::shared::models::schema::{kb_group_associations, rbac_groups}; - kb_group_associations::table - .inner_join(rbac_groups::table.on(rbac_groups::id.eq(kb_group_associations::group_id))) - .filter(kb_group_associations::kb_id.eq(kb_id)) - .filter(rbac_groups::is_active.eq(true)) - .select(RbacGroup::as_select()) - .load::(&mut db_conn) - .map_err(|e| format!("Query error: {e}")) + + #[derive(QueryableByName)] + struct GroupRow { + #[diesel(sql_type = diesel::sql_types::Uuid)] + id: Uuid, + #[diesel(sql_type = diesel::sql_types::Text)] + name: String, + #[diesel(sql_type = diesel::sql_types::Text)] + display_name: String, + #[diesel(sql_type = diesel::sql_types::Nullable)] + description: Option, + #[diesel(sql_type = diesel::sql_types::Bool)] + is_active: bool, + #[diesel(sql_type = diesel::sql_types::Nullable)] + parent_group_id: Option, + #[diesel(sql_type = diesel::sql_types::Nullable)] + created_by: Option, + #[diesel(sql_type = diesel::sql_types::Timestamptz)] + created_at: chrono::DateTime, + #[diesel(sql_type = diesel::sql_types::Timestamptz)] + updated_at: chrono::DateTime, + } + + let rows: Vec = diesel::sql_query( + "SELECT rg.id, rg.name, rg.display_name, rg.description, rg.is_active, + rg.parent_group_id, rg.created_by, rg.created_at, rg.updated_at + FROM research.kb_group_associations kga + JOIN core.rbac_groups rg ON rg.id = kga.group_id + WHERE kga.kb_id = $1 AND rg.is_active = true" + ) + .bind::(kb_id) + .load(&mut db_conn) + .map_err(|e| format!("Query error: {e}"))?; + + let groups: Vec = rows.into_iter().map(|r| RbacGroup { + id: r.id, + name: r.name, + display_name: r.display_name, + description: r.description, + is_active: r.is_active, + parent_group_id: r.parent_group_id, + created_by: r.created_by, + created_at: r.created_at, + updated_at: r.updated_at, + }).collect(); + + Ok(groups) }) .await;