generalbots/src/sources/knowledge_base.rs
Rodrigo Rodriguez (Pragmatismo) 479950945b feat(auth): Add OTP password display on bootstrap and fix Zitadel login flow
- Add generate_secure_password() for OTP generation during admin bootstrap
- Display admin credentials (username/password) in console on first run
- Save credentials to ~/.gb-setup-credentials file
- Fix Zitadel client to support PAT token authentication
- Replace OAuth2 password grant with Zitadel Session API for login
- Fix get_current_user to fetch user data from Zitadel session
- Return session_id as access_token for proper authentication
- Set email as verified on user creation to skip verification
- Add password grant type to OAuth application config
- Update directory_setup to include proper redirect URIs
2026-01-06 22:56:35 -03:00

1048 lines
34 KiB
Rust

use crate::shared::state::AppState;
use axum::{
extract::{Multipart, Path, Query, State},
response::{Html, IntoResponse},
routing::{get, post},
Json, Router,
};
use chrono::{DateTime, Utc};
use diesel::prelude::*;
use log::{error, info, warn};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::sync::Arc;
use uuid::Uuid;
/// Serializable record describing one knowledge-base source.
///
/// NOTE(review): none of the handlers visible in this part of the file
/// construct this type; confirm where it is produced and consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeSource {
/// Identifier of the source.
pub id: String,
/// Human-readable name (presumably the uploaded file name; verify against callers).
pub name: String,
/// Kind of document the source was created from.
pub source_type: SourceType,
/// File path of the source, when it has one.
pub file_path: Option<String>,
/// URL of the source, when it has one.
pub url: Option<String>,
/// Hash of the extracted text content.
pub content_hash: String,
/// Number of chunks stored for this source.
pub chunk_count: i32,
/// Current ingestion/indexing state.
pub status: SourceStatus,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Last modification timestamp (UTC).
pub updated_at: DateTime<Utc>,
/// Time the source was last indexed, if it has been.
pub indexed_at: Option<DateTime<Utc>>,
/// Free-form JSON metadata attached to the source.
pub metadata: serde_json::Value,
}
/// Kind of document a knowledge source was created from.
///
/// Serialized by serde as the lowercase variant name (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
/// PDF document.
Pdf,
/// Word document (`docx`; the `From<&str>` impl also maps `doc` here).
Docx,
/// Plain text.
Txt,
/// Markdown text.
Markdown,
/// HTML page.
Html,
/// Comma-separated values.
Csv,
/// Excel workbook (`xlsx`; the `From<&str>` impl also maps `xls` here).
Xlsx,
/// Remote URL / web source.
Url,
/// Fallback used by `From<&str>` for any unrecognized type string.
Custom,
}
/// Renders each variant as its lowercase name (`pdf`, `docx`, `txt`, ...).
impl std::fmt::Display for SourceType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Self::Pdf => "pdf",
            Self::Docx => "docx",
            Self::Txt => "txt",
            Self::Markdown => "markdown",
            Self::Html => "html",
            Self::Csv => "csv",
            Self::Xlsx => "xlsx",
            Self::Url => "url",
            Self::Custom => "custom",
        };
        f.write_str(label)
    }
}
/// Case-insensitive conversion from a type name or file extension.
///
/// Anything that is not a recognized alias becomes [`SourceType::Custom`].
impl From<&str> for SourceType {
    fn from(s: &str) -> Self {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "pdf" => Self::Pdf,
            "doc" | "docx" => Self::Docx,
            "text" | "txt" => Self::Txt,
            "markdown" | "md" => Self::Markdown,
            "htm" | "html" => Self::Html,
            "csv" => Self::Csv,
            "xls" | "xlsx" => Self::Xlsx,
            "web" | "url" => Self::Url,
            _ => Self::Custom,
        }
    }
}
/// Ingestion/indexing state of a knowledge source.
///
/// Serialized by serde as the lowercase variant name (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SourceStatus {
/// Not processed yet; also the fallback of `From<&str>` for unknown strings.
Pending,
/// Currently being ingested (the upload handler inserts rows with this status).
Processing,
/// Chunks stored; only sources in this state are searched by the query handler.
Indexed,
/// Ingestion failed.
Failed,
/// Queued for re-indexing by the reindex handler.
Reindexing,
}
/// Renders each status as its lowercase name (`pending`, `indexed`, ...).
impl std::fmt::Display for SourceStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Self::Pending => "pending",
            Self::Processing => "processing",
            Self::Indexed => "indexed",
            Self::Failed => "failed",
            Self::Reindexing => "reindexing",
        };
        f.write_str(label)
    }
}
/// Case-insensitive conversion from a status string.
///
/// Unknown values (and the literal `pending`) map to [`SourceStatus::Pending`].
impl From<&str> for SourceStatus {
    fn from(s: &str) -> Self {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "reindexing" => Self::Reindexing,
            "failed" => Self::Failed,
            "indexed" => Self::Indexed,
            "processing" => Self::Processing,
            _ => Self::Pending,
        }
    }
}
/// Serializable representation of one chunk of a source document.
///
/// NOTE(review): not constructed by the code visible here; the upload handler
/// writes chunks to `knowledge_chunks` directly via SQL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
/// Identifier of the chunk.
pub id: String,
/// Identifier of the source this chunk belongs to.
pub source_id: String,
/// Position of the chunk within its source.
pub chunk_index: i32,
/// Text content of the chunk.
pub content: String,
/// Token count of `content` (the upload handler stores an estimate).
pub token_count: i32,
/// Optional embedding vector; presumably filled by an indexing step — TODO confirm.
pub embedding: Option<Vec<f32>>,
/// Free-form JSON metadata attached to the chunk.
pub metadata: serde_json::Value,
}
/// JSON body returned by the document upload handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadResponse {
/// `true` when the document was ingested.
pub success: bool,
/// Identifier of the created source; `None` on failure.
pub source_id: Option<String>,
/// Human-readable outcome or error description.
pub message: String,
/// Number of chunks produced from the document; `None` on failure.
pub chunks_created: Option<i32>,
}
/// JSON body accepted by the knowledge-base query handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRequest {
/// Search text; a blank query yields an empty result set.
pub query: String,
/// When set, only sources in this collection are searched.
pub collection: Option<String>,
/// Maximum number of results; the handler defaults to 5 and caps it at 20.
pub top_k: Option<usize>,
/// Minimum score a result must reach; the handler defaults to 0.0.
pub min_score: Option<f32>,
/// NOTE(review): not read by the query handler visible in this file.
pub include_metadata: Option<bool>,
}
/// One matching chunk returned by the knowledge-base query handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
/// Text of the matching chunk.
pub content: String,
/// Name of the source the chunk belongs to.
pub source_name: String,
/// Identifier of the source the chunk belongs to.
pub source_id: String,
/// Position of the chunk within its source.
pub chunk_index: i32,
/// Relevance score (the query handler fills this from PostgreSQL `ts_rank`).
pub score: f32,
/// JSON metadata; the query handler currently always sends an empty object.
pub metadata: serde_json::Value,
}
/// JSON body returned by the knowledge-base query handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
/// Matching chunks.
pub results: Vec<QueryResult>,
/// The query text, echoed back from the request.
pub query: String,
/// Number of entries in `results`.
pub total_results: usize,
/// Wall-clock time spent handling the request, in milliseconds.
pub processing_time_ms: u64,
}
/// Query-string parameters accepted by the source listing handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListSourcesQuery {
/// When set, only sources with exactly this status are listed.
pub status: Option<String>,
/// When set, only sources with exactly this type are listed.
pub source_type: Option<String>,
/// 1-based page number; the handler defaults to 1 and raises smaller values to 1.
pub page: Option<i32>,
/// Page size; the handler defaults to 20 and caps it at 100.
pub per_page: Option<i32>,
}
/// JSON body accepted by the reindex handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReindexRequest {
/// Sources to queue. When absent, every source currently `indexed` is queued;
/// an empty list queues nothing.
pub source_ids: Option<Vec<String>>,
/// NOTE(review): not read by the reindex handler visible in this file.
pub force: Option<bool>,
}
/// Row shape for raw `knowledge_sources` queries issued through `diesel::sql_query`.
///
/// `QueryableByName` looks each field up by column name, and the default
/// column name is the field identifier itself. The underscore-prefixed fields
/// (kept that way because nothing reads them) therefore need an explicit
/// `column_name`, otherwise loading fails because no `_file_path`, `_url`, ...
/// column exists in the result set.
#[derive(Debug, QueryableByName)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct KnowledgeSourceRow {
    #[diesel(sql_type = diesel::sql_types::Text)]
    id: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    name: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    source_type: String,
    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
    #[diesel(column_name = file_path)]
    _file_path: Option<String>,
    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
    #[diesel(column_name = url)]
    _url: Option<String>,
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = content_hash)]
    _content_hash: String,
    #[diesel(sql_type = diesel::sql_types::Integer)]
    chunk_count: i32,
    #[diesel(sql_type = diesel::sql_types::Text)]
    status: String,
    /// Selected as `created_at::text`, hence a string rather than a timestamp.
    #[diesel(sql_type = diesel::sql_types::Text)]
    created_at: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = updated_at)]
    _updated_at: String,
    /// Selected as `indexed_at::text`; `None` until the source has been indexed.
    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
    indexed_at: Option<String>,
}
/// Row shape for raw `knowledge_chunks` queries.
///
/// Currently unused (hence the underscore-prefixed names). Each field carries
/// an explicit `column_name` matching the columns written by the upload
/// handler, because `QueryableByName` would otherwise look for columns
/// literally named `_id`, `_source_id`, ... and fail at load time.
#[derive(Debug, QueryableByName)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct _ChunkRow {
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = id)]
    _id: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = source_id)]
    _source_id: String,
    #[diesel(sql_type = diesel::sql_types::Integer)]
    #[diesel(column_name = chunk_index)]
    _chunk_index: i32,
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = content)]
    _content: String,
    #[diesel(sql_type = diesel::sql_types::Integer)]
    #[diesel(column_name = token_count)]
    _token_count: i32,
}
/// Row shape returned by the full-text search query in the query handler.
///
/// The query aliases the chunk id as `chunk_id`; the unused field keeps its
/// underscore prefix, so the column name is spelled out explicitly — without
/// it `QueryableByName` would look for a `_chunk_id` column and every search
/// would fail to load.
#[derive(Debug, QueryableByName)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct SearchResultRow {
    #[diesel(sql_type = diesel::sql_types::Text)]
    #[diesel(column_name = chunk_id)]
    _chunk_id: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    content: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    source_id: String,
    #[diesel(sql_type = diesel::sql_types::Text)]
    source_name: String,
    #[diesel(sql_type = diesel::sql_types::Integer)]
    chunk_index: i32,
    /// `ts_rank` result, cast to `FLOAT4` in the query.
    #[diesel(sql_type = diesel::sql_types::Float)]
    score: f32,
}
/// Builds the router exposing the knowledge-base endpoints.
///
/// Paths come from `ApiUrls`; the by-id path serves both `GET` (details) and
/// `DELETE` (removal).
pub fn configure_knowledge_base_routes() -> Router<Arc<AppState>> {
    use crate::core::urls::ApiUrls;

    // One path, two verbs: fetch details or delete the source.
    let source_by_id = get(handle_get_source).delete(handle_delete_source);

    Router::new()
        .route(ApiUrls::SOURCES_KB_UPLOAD, post(handle_upload_document))
        .route(ApiUrls::SOURCES_KB_LIST, get(handle_list_sources))
        .route(ApiUrls::SOURCES_KB_QUERY, post(handle_query_knowledge_base))
        .route(ApiUrls::SOURCES_KB_BY_ID, source_by_id)
        .route(ApiUrls::SOURCES_KB_REINDEX, post(handle_reindex_sources))
        .route(ApiUrls::SOURCES_KB_STATS, get(handle_get_stats))
}
pub async fn handle_upload_document(
State(state): State<Arc<AppState>>,
mut multipart: Multipart,
) -> impl IntoResponse {
let mut file_name = String::new();
let mut file_data: Vec<u8> = Vec::new();
let mut collection = "default".to_string();
while let Ok(Some(field)) = multipart.next_field().await {
let name = field.name().unwrap_or("").to_string();
match name.as_str() {
"file" => {
file_name = field.file_name().unwrap_or("unknown").to_string();
if let Ok(data) = field.bytes().await {
file_data = data.to_vec();
}
}
"collection" => {
if let Ok(text) = field.text().await {
collection = text;
}
}
_ => {}
}
}
if file_data.is_empty() {
return Json(UploadResponse {
success: false,
source_id: None,
message: "No file provided".to_string(),
chunks_created: None,
});
}
let extension = file_name.rsplit('.').next().unwrap_or("txt").to_lowercase();
let source_type = SourceType::from(extension.as_str());
let content = match extract_text_content(&file_data, &source_type) {
Ok(text) => text,
Err(e) => {
error!("Failed to extract text from {}: {}", file_name, e);
return Json(UploadResponse {
success: false,
source_id: None,
message: format!("Failed to extract text: {}", e),
chunks_created: None,
});
}
};
let content_hash = compute_content_hash(&content);
let source_id = Uuid::new_v4().to_string();
let chunks = chunk_text(&content, 512, 50);
let chunk_count = chunks.len() as i32;
let conn = state.conn.clone();
let source_id_clone = source_id.clone();
let file_name_clone = file_name.clone();
let source_type_str = source_type.to_string();
let content_hash_clone = content_hash.clone();
let collection_clone = collection.clone();
let chunks_clone = chunks.clone();
let result = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return Err(format!("Database connection error: {}", e));
}
};
let insert_result = diesel::sql_query(
"INSERT INTO knowledge_sources (id, name, source_type, content_hash, chunk_count, status, collection, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, 'processing', $6, NOW(), NOW())
ON CONFLICT (id) DO NOTHING"
)
.bind::<diesel::sql_types::Text, _>(&source_id_clone)
.bind::<diesel::sql_types::Text, _>(&file_name_clone)
.bind::<diesel::sql_types::Text, _>(&source_type_str)
.bind::<diesel::sql_types::Text, _>(&content_hash_clone)
.bind::<diesel::sql_types::Integer, _>(chunk_count)
.bind::<diesel::sql_types::Text, _>(&collection_clone)
.execute(&mut db_conn);
if let Err(e) = insert_result {
error!("Failed to insert source: {}", e);
return Err(format!("Failed to insert source: {}", e));
}
for (idx, chunk_content) in chunks_clone.iter().enumerate() {
let chunk_id = Uuid::new_v4().to_string();
let token_count = estimate_tokens(chunk_content);
let chunk_result = diesel::sql_query(
"INSERT INTO knowledge_chunks (id, source_id, chunk_index, content, token_count, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (id) DO NOTHING"
)
.bind::<diesel::sql_types::Text, _>(&chunk_id)
.bind::<diesel::sql_types::Text, _>(&source_id_clone)
.bind::<diesel::sql_types::Integer, _>(idx as i32)
.bind::<diesel::sql_types::Text, _>(chunk_content)
.bind::<diesel::sql_types::Integer, _>(token_count)
.execute(&mut db_conn);
if let Err(e) = chunk_result {
warn!("Failed to insert chunk {}: {}", idx, e);
}
}
let _ = diesel::sql_query(
"UPDATE knowledge_sources SET status = 'indexed', indexed_at = NOW(), updated_at = NOW() WHERE id = $1"
)
.bind::<diesel::sql_types::Text, _>(&source_id_clone)
.execute(&mut db_conn);
Ok(())
})
.await;
match result {
Ok(Ok(())) => {
info!(
"Successfully ingested {} with {} chunks",
file_name, chunk_count
);
Json(UploadResponse {
success: true,
source_id: Some(source_id),
message: format!(
"Successfully ingested '{}' with {} chunks",
file_name, chunk_count
),
chunks_created: Some(chunk_count),
})
}
Ok(Err(e)) => Json(UploadResponse {
success: false,
source_id: None,
message: e,
chunks_created: None,
}),
Err(e) => Json(UploadResponse {
success: false,
source_id: None,
message: format!("Task error: {}", e),
chunks_created: None,
}),
}
}
pub async fn handle_list_sources(
State(state): State<Arc<AppState>>,
Query(params): Query<ListSourcesQuery>,
) -> impl IntoResponse {
let conn = state.conn.clone();
let status_filter = params.status.clone();
let type_filter = params.source_type.clone();
let page = params.page.unwrap_or(1).max(1);
let per_page = params.per_page.unwrap_or(20).min(100);
let offset = (page - 1) * per_page;
let sources = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return Vec::new();
}
};
let mut query = String::from(
"SELECT id, name, source_type, file_path, url, content_hash, chunk_count, status,
created_at::text, updated_at::text, indexed_at::text
FROM knowledge_sources WHERE 1=1",
);
if let Some(ref status) = status_filter {
let _ = write!(query, " AND status = '{}'", status.replace('\'', "''"));
}
if let Some(ref stype) = type_filter {
let _ = write!(query, " AND source_type = '{}'", stype.replace('\'', "''"));
}
let _ = write!(query, " ORDER BY created_at DESC LIMIT {per_page} OFFSET {offset}");
diesel::sql_query(&query)
.load::<KnowledgeSourceRow>(&mut db_conn)
.unwrap_or_default()
})
.await
.unwrap_or_default();
let mut html = String::new();
html.push_str("<div class=\"sources-list\">");
if sources.is_empty() {
html.push_str("<div class=\"empty-state\">");
html.push_str("<p>No knowledge sources found</p>");
html.push_str("<p class=\"hint\">Upload documents to build your knowledge base</p>");
html.push_str("</div>");
} else {
for source in &sources {
let status_class = match source.status.as_str() {
"indexed" => "status-success",
"processing" => "status-processing",
"failed" => "status-error",
"pending" => "status-pending",
_ => "status-unknown",
};
let type_icon = match source.source_type.as_str() {
"pdf" => "📄",
"docx" | "doc" => "📝",
"txt" | "text" => "📃",
"markdown" | "md" => "📋",
"html" => "🌐",
"csv" => "📊",
"xlsx" | "xls" => "📈",
"url" => "🔗",
_ => "📁",
};
html.push_str("<div class=\"source-item\" data-id=\"");
html.push_str(&html_escape(&source.id));
html.push_str("\">");
html.push_str("<div class=\"source-icon\">");
html.push_str(type_icon);
html.push_str("</div>");
html.push_str("<div class=\"source-info\">");
html.push_str("<h4 class=\"source-name\">");
html.push_str(&html_escape(&source.name));
html.push_str("</h4>");
html.push_str("<div class=\"source-meta\">");
html.push_str("<span class=\"source-type\">");
html.push_str(&html_escape(&source.source_type));
html.push_str("</span>");
html.push_str("<span class=\"source-chunks\">");
html.push_str(&source.chunk_count.to_string());
html.push_str(" chunks</span>");
html.push_str("<span class=\"source-status ");
html.push_str(status_class);
html.push_str("\">");
html.push_str(&html_escape(&source.status));
html.push_str("</span>");
html.push_str("</div>");
html.push_str("</div>");
html.push_str("<div class=\"source-actions\">");
html.push_str("<button class=\"btn-icon-sm\" title=\"Reindex\" ");
html.push_str("hx-post=\"/api/sources/kb/reindex\" ");
html.push_str("hx-vals='{\"source_ids\":[\"");
html.push_str(&html_escape(&source.id));
html.push_str("\"]}' hx-swap=\"none\">");
html.push_str("<svg width=\"14\" height=\"14\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\">");
html.push_str("<polyline points=\"23 4 23 10 17 10\"></polyline>");
html.push_str("<path d=\"M20.49 15a9 9 0 1 1-2.12-9.36L23 10\"></path>");
html.push_str("</svg></button>");
html.push_str("<button class=\"btn-icon-sm btn-danger\" title=\"Delete\" ");
html.push_str("hx-delete=\"/api/sources/kb/");
html.push_str(&html_escape(&source.id));
html.push_str("\" hx-confirm=\"Delete this source?\" hx-target=\"closest .source-item\" hx-swap=\"outerHTML\">");
html.push_str("<svg width=\"14\" height=\"14\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\">");
html.push_str("<polyline points=\"3 6 5 6 21 6\"></polyline>");
html.push_str("<path d=\"M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6m3 0V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2\"></path>");
html.push_str("</svg></button>");
html.push_str("</div>");
html.push_str("</div>");
}
}
html.push_str("</div>");
Html(html)
}
pub async fn handle_query_knowledge_base(
State(state): State<Arc<AppState>>,
Json(payload): Json<QueryRequest>,
) -> impl IntoResponse {
let start_time = std::time::Instant::now();
if payload.query.trim().is_empty() {
return Json(QueryResponse {
results: Vec::new(),
query: payload.query,
total_results: 0,
processing_time_ms: 0,
});
}
let conn = state.conn.clone();
let query = payload.query.clone();
let top_k = payload.top_k.unwrap_or(5).min(20);
let min_score = payload.min_score.unwrap_or(0.0);
let collection = payload.collection.clone();
let results = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return Vec::new();
}
};
let search_pattern = format!("%{}%", query.to_lowercase());
let mut sql = String::from(
"SELECT
kc.id as chunk_id,
kc.content,
kc.source_id,
ks.name as source_name,
kc.chunk_index,
CAST(ts_rank(to_tsvector('english', kc.content), plainto_tsquery('english', $1)) AS FLOAT4) as score
FROM knowledge_chunks kc
JOIN knowledge_sources ks ON kc.source_id = ks.id
WHERE ks.status = 'indexed'
AND (LOWER(kc.content) LIKE $2
OR to_tsvector('english', kc.content) @@ plainto_tsquery('english', $1))",
);
if collection.is_some() {
sql.push_str(" AND ks.collection = $3");
}
let _ = write!(sql, " ORDER BY score DESC, kc.chunk_index ASC LIMIT {top_k}");
let search_results: Vec<SearchResultRow> = if let Some(ref coll) = collection {
diesel::sql_query(&sql)
.bind::<diesel::sql_types::Text, _>(&query)
.bind::<diesel::sql_types::Text, _>(&search_pattern)
.bind::<diesel::sql_types::Text, _>(coll)
.load(&mut db_conn)
.unwrap_or_default()
} else {
diesel::sql_query(&sql)
.bind::<diesel::sql_types::Text, _>(&query)
.bind::<diesel::sql_types::Text, _>(&search_pattern)
.load(&mut db_conn)
.unwrap_or_default()
};
search_results
.into_iter()
.filter(|r| r.score >= min_score)
.map(|r| QueryResult {
content: r.content,
source_name: r.source_name,
source_id: r.source_id,
chunk_index: r.chunk_index,
score: r.score,
metadata: serde_json::json!({}),
})
.collect::<Vec<_>>()
})
.await
.unwrap_or_default();
let processing_time_ms = start_time.elapsed().as_millis() as u64;
Json(QueryResponse {
total_results: results.len(),
results,
query: payload.query,
processing_time_ms,
})
}
pub async fn handle_get_source(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> impl IntoResponse {
let conn = state.conn.clone();
let source = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return None;
}
};
let sources: Vec<KnowledgeSourceRow> = diesel::sql_query(
"SELECT id, name, source_type, file_path, url, content_hash, chunk_count, status,
created_at::text, updated_at::text, indexed_at::text
FROM knowledge_sources WHERE id = $1",
)
.bind::<diesel::sql_types::Text, _>(&id)
.load(&mut db_conn)
.unwrap_or_default();
sources.into_iter().next()
})
.await
.unwrap_or(None);
match source {
Some(s) => {
let mut html = String::new();
html.push_str("<div class=\"source-detail\">");
html.push_str("<h3>");
html.push_str(&html_escape(&s.name));
html.push_str("</h3>");
html.push_str("<div class=\"detail-grid\">");
html.push_str("<div class=\"detail-item\"><label>Type:</label><span>");
html.push_str(&html_escape(&s.source_type));
html.push_str("</span></div>");
html.push_str("<div class=\"detail-item\"><label>Status:</label><span>");
html.push_str(&html_escape(&s.status));
html.push_str("</span></div>");
html.push_str("<div class=\"detail-item\"><label>Chunks:</label><span>");
html.push_str(&s.chunk_count.to_string());
html.push_str("</span></div>");
html.push_str("<div class=\"detail-item\"><label>Created:</label><span>");
html.push_str(&html_escape(&s.created_at));
html.push_str("</span></div>");
if let Some(indexed) = &s.indexed_at {
html.push_str("<div class=\"detail-item\"><label>Indexed:</label><span>");
html.push_str(&html_escape(indexed));
html.push_str("</span></div>");
}
html.push_str("</div></div>");
Html(html)
}
None => Html("<div class=\"error\">Source not found</div>".to_string()),
}
}
pub async fn handle_delete_source(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> impl IntoResponse {
let conn = state.conn.clone();
let result = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return false;
}
};
let _ = diesel::sql_query("DELETE FROM knowledge_chunks WHERE source_id = $1")
.bind::<diesel::sql_types::Text, _>(&id)
.execute(&mut db_conn);
let delete_result = diesel::sql_query("DELETE FROM knowledge_sources WHERE id = $1")
.bind::<diesel::sql_types::Text, _>(&id)
.execute(&mut db_conn);
delete_result.is_ok()
})
.await
.unwrap_or(false);
if result {
Html("".to_string())
} else {
Html("<div class=\"error\">Failed to delete source</div>".to_string())
}
}
pub async fn handle_reindex_sources(
State(state): State<Arc<AppState>>,
Json(payload): Json<ReindexRequest>,
) -> impl IntoResponse {
let conn = state.conn.clone();
let source_ids = payload.source_ids.clone();
let result = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return 0;
}
};
let sql = if let Some(ids) = source_ids {
if ids.is_empty() {
return 0;
}
let placeholders: Vec<String> = ids.iter().map(|id| format!("'{}'", id.replace('\'', "''"))).collect();
format!(
"UPDATE knowledge_sources SET status = 'reindexing', updated_at = NOW() WHERE id IN ({})",
placeholders.join(",")
)
} else {
"UPDATE knowledge_sources SET status = 'reindexing', updated_at = NOW() WHERE status = 'indexed'".to_string()
};
diesel::sql_query(&sql)
.execute(&mut db_conn)
.unwrap_or(0)
})
.await
.unwrap_or(0);
Json(serde_json::json!({
"success": true,
"sources_queued": result,
"message": format!("{} sources queued for reindexing", result)
}))
}
pub async fn handle_get_stats(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let conn = state.conn.clone();
let stats = tokio::task::spawn_blocking(move || {
let mut db_conn = match conn.get() {
Ok(c) => c,
Err(e) => {
error!("DB connection error: {}", e);
return serde_json::json!({
"total_sources": 0,
"total_chunks": 0,
"indexed_sources": 0,
"pending_sources": 0,
"failed_sources": 0
});
}
};
#[derive(Debug, QueryableByName)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct CountRow {
#[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64,
}
let total_sources: i64 = diesel::sql_query("SELECT COUNT(*) as count FROM knowledge_sources")
.load::<CountRow>(&mut db_conn)
.map(|v| v.first().map(|r| r.count).unwrap_or(0))
.unwrap_or(0);
let total_chunks: i64 = diesel::sql_query("SELECT COUNT(*) as count FROM knowledge_chunks")
.load::<CountRow>(&mut db_conn)
.map(|v| v.first().map(|r| r.count).unwrap_or(0))
.unwrap_or(0);
let indexed: i64 = diesel::sql_query("SELECT COUNT(*) as count FROM knowledge_sources WHERE status = 'indexed'")
.load::<CountRow>(&mut db_conn)
.map(|v| v.first().map(|r| r.count).unwrap_or(0))
.unwrap_or(0);
let pending: i64 = diesel::sql_query("SELECT COUNT(*) as count FROM knowledge_sources WHERE status IN ('pending', 'processing', 'reindexing')")
.load::<CountRow>(&mut db_conn)
.map(|v| v.first().map(|r| r.count).unwrap_or(0))
.unwrap_or(0);
let failed: i64 = diesel::sql_query("SELECT COUNT(*) as count FROM knowledge_sources WHERE status = 'failed'")
.load::<CountRow>(&mut db_conn)
.map(|v| v.first().map(|r| r.count).unwrap_or(0))
.unwrap_or(0);
serde_json::json!({
"total_sources": total_sources,
"total_chunks": total_chunks,
"indexed_sources": indexed,
"pending_sources": pending,
"failed_sources": failed
})
})
.await
.unwrap_or_else(|_| serde_json::json!({
"error": "Failed to fetch stats"
}));
Json(stats)
}
/// Extracts plain text from raw file bytes according to the source type.
///
/// * DOCX / XLSX are delegated to their dedicated extractors.
/// * HTML is decoded lossily as UTF-8 and stripped of tags.
/// * PDF needs the `drive` feature; with it, an extraction failure is logged
///   and yields an empty string, without it an error is returned.
/// * Every other type is decoded lossily as UTF-8 and returned unchanged.
fn extract_text_content(
    data: &[u8],
    source_type: &SourceType,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    match source_type {
        SourceType::Docx => extract_docx_text(data),
        SourceType::Xlsx => extract_xlsx_text(data),
        SourceType::Html => {
            let markup = String::from_utf8_lossy(data);
            Ok(strip_html_tags(&markup))
        }
        SourceType::Pdf => {
            #[cfg(feature = "drive")]
            {
                match pdf_extract::extract_text_from_mem(data) {
                    Ok(text) => Ok(text),
                    Err(e) => {
                        warn!("PDF extraction failed: {}", e);
                        Ok(String::new())
                    }
                }
            }
            #[cfg(not(feature = "drive"))]
            {
                Err("PDF extraction not available without 'drive' feature".into())
            }
        }
        // Text-like and unrecognized types are taken verbatim.
        SourceType::Txt
        | SourceType::Markdown
        | SourceType::Csv
        | SourceType::Url
        | SourceType::Custom => Ok(String::from_utf8_lossy(data).to_string()),
    }
}
/// Extracts the visible text of a DOCX file.
///
/// Opens the bytes as a ZIP archive, reads `word/document.xml` and collects
/// the content of every `<w:t>` run. Paragraphs (`<w:p>`) and explicit breaks
/// (`<w:br>`, `<w:cr>`) start a new line and `<w:tab>` becomes a space. Runs
/// within a paragraph are joined without a separator, because Word stores the
/// spaces inside the run text itself. Blank lines are dropped and every line
/// is trimmed.
///
/// Returns an empty string when the archive has no `word/document.xml`.
///
/// # Errors
/// Fails when the data is not a readable ZIP archive or the document part
/// cannot be read as UTF-8.
fn extract_docx_text(
    data: &[u8],
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    use std::io::{Cursor, Read};

    /// Decodes the five predefined XML entities. `&amp;` is handled last so
    /// that `&amp;lt;` yields the literal text `&lt;` rather than `<`.
    fn decode_entities(raw: &str) -> String {
        raw.replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&quot;", "\"")
            .replace("&apos;", "'")
            .replace("&amp;", "&")
    }

    let mut archive = zip::ZipArchive::new(Cursor::new(data))
        .map_err(|e| format!("Failed to open DOCX: {e}"))?;
    let Ok(mut file) = archive.by_name("word/document.xml") else {
        return Ok(String::new());
    };
    let mut xml = String::new();
    file.read_to_string(&mut xml)?;

    let mut result = String::new();
    // After splitting on '<' every piece looks like `tag attrs>text`.
    for part in xml.split('<') {
        let Some((tag, text)) = part.split_once('>') else {
            continue;
        };
        // Closing tags carry no content of their own.
        if tag.starts_with('/') {
            continue;
        }
        let self_closing = tag.ends_with('/');
        // Compare the exact element name so `w:tbl`, `w:tc`, `w:pPr`, ... are
        // not mistaken for `w:t` / `w:p`.
        let name = tag
            .trim_end_matches('/')
            .split_whitespace()
            .next()
            .unwrap_or("");
        match name {
            "w:p" | "w:br" | "w:cr" => result.push('\n'),
            "w:tab" => result.push(' '),
            "w:t" if !self_closing => result.push_str(&decode_entities(text)),
            _ => {}
        }
    }

    Ok(result
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join("\n"))
}
/// Extracts cell text from an XLSX workbook.
///
/// Reads the shared-string table (`xl/sharedStrings.xml`) and then the
/// worksheets `xl/worksheets/sheet1.xml` .. `sheet10.xml`, stopping at the
/// first missing sheet. Each cell value is followed by a tab and each sheet by
/// a newline.
///
/// * One shared string is built per `<si>` item by concatenating all of its
///   `<t>` runs, so rich-text entries do not shift the indices.
/// * A `<v>` value is resolved through the shared-string table only when the
///   cell is declared `t="s"`; numeric and other cells keep their literal
///   value. Cells without a `<v>` element (e.g. inline strings) are skipped.
///
/// # Errors
/// Fails when the data is not a readable ZIP archive or a part cannot be read
/// as UTF-8.
fn extract_xlsx_text(
    data: &[u8],
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    use std::io::{Cursor, Read};

    /// Decodes the five predefined XML entities. `&amp;` is handled last so
    /// that `&amp;lt;` yields the literal text `&lt;` rather than `<`.
    fn decode_entities(raw: &str) -> String {
        raw.replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&quot;", "\"")
            .replace("&apos;", "'")
            .replace("&amp;", "&")
    }

    /// True when `rest` (the text right after a candidate tag name) shows the
    /// name is complete, i.e. `<c ...>` rather than `<cols>` or `<cfRule>`.
    fn ends_tag_name(rest: &str) -> bool {
        rest.starts_with(|c: char| c == '>' || c == '/' || c.is_whitespace())
    }

    let mut archive = zip::ZipArchive::new(Cursor::new(data))
        .map_err(|e| format!("Failed to open XLSX: {e}"))?;

    let mut shared_strings: Vec<String> = Vec::new();
    if let Ok(mut file) = archive.by_name("xl/sharedStrings.xml") {
        let mut xml = String::new();
        file.read_to_string(&mut xml)?;
        for item in xml.split("<si").skip(1).filter(|item| ends_tag_name(item)) {
            let mut text = String::new();
            for run in item.split("<t").skip(1).filter(|run| ends_tag_name(run)) {
                let Some(open_end) = run.find('>') else {
                    continue;
                };
                // `<t/>` is an empty run.
                if run[..open_end].ends_with('/') {
                    continue;
                }
                let body = &run[open_end + 1..];
                if let Some(close) = body.find("</t>") {
                    text.push_str(&decode_entities(&body[..close]));
                }
            }
            // Push even when empty: positions in this table are the indices cells refer to.
            shared_strings.push(text);
        }
    }

    let mut content = String::new();
    for i in 1..=10 {
        let sheet_name = format!("xl/worksheets/sheet{i}.xml");
        let Ok(mut file) = archive.by_name(&sheet_name) else {
            break;
        };
        let mut xml = String::new();
        file.read_to_string(&mut xml)?;

        for cell in xml.split("<c").skip(1).filter(|cell| ends_tag_name(cell)) {
            let Some(open_end) = cell.find('>') else {
                continue;
            };
            let attrs = &cell[..open_end];
            // `<c .../>` is an empty cell.
            if attrs.ends_with('/') {
                continue;
            }
            let body = &cell[open_end + 1..];
            let Some(value_start) = body.find("<v>") else {
                continue;
            };
            let value_body = &body[value_start + 3..];
            let Some(value_end) = value_body.find("</v>") else {
                continue;
            };
            let value = &value_body[..value_end];

            let is_shared = attrs.contains(" t=\"s\"") || attrs.contains(" t='s'");
            if is_shared {
                let resolved = value
                    .trim()
                    .parse::<usize>()
                    .ok()
                    .and_then(|idx| shared_strings.get(idx));
                if let Some(text) = resolved {
                    content.push_str(text);
                    content.push('\t');
                }
            } else {
                content.push_str(&decode_entities(value));
                content.push('\t');
            }
        }
        content.push('\n');
    }
    Ok(content)
}
/// Removes everything between `<` and `>` and tidies the remaining text.
///
/// Each remaining line is trimmed and blank lines are dropped. A `>` that
/// appears outside of a tag (as in `1 > 0`) is ordinary text and is kept.
/// Entities are not decoded and the contents of `<script>` / `<style>`
/// elements are not filtered out.
fn strip_html_tags(html: &str) -> String {
    let mut text = String::with_capacity(html.len());
    let mut in_tag = false;
    for c in html.chars() {
        match c {
            '<' => in_tag = true,
            // Only a '>' that closes an open tag is markup.
            '>' if in_tag => in_tag = false,
            _ if !in_tag => text.push(c),
            _ => {}
        }
    }
    text.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}
/// Returns a lowercase hexadecimal 64-bit hash of `content`.
///
/// NOTE(review): this uses the standard library's `DefaultHasher`, whose
/// algorithm is documented as unspecified and may change between Rust
/// releases — confirm that is acceptable for a value that gets persisted.
fn compute_content_hash(content: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(content, &mut state);
    let digest = state.finish();
    format!("{digest:x}")
}
/// Splits `text` into chunks of at most `chunk_size` whitespace-separated
/// words, joined by single spaces.
///
/// Consecutive chunks share `overlap` words; when `overlap` is not smaller
/// than the chunk size the chunks do not overlap at all. A `chunk_size` of 0
/// is treated as 1 (it previously made the loop run forever). Text without
/// any words yields an empty vector.
fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    let words: Vec<&str> = text.split_whitespace().collect();
    if words.is_empty() {
        return Vec::new();
    }

    // A zero-sized window could never advance through the words.
    let chunk_size = chunk_size.max(1);
    if words.len() <= chunk_size {
        return vec![words.join(" ")];
    }

    // Distance between chunk starts; always at least 1, so the loop terminates.
    let step = if overlap < chunk_size {
        chunk_size - overlap
    } else {
        chunk_size
    };

    let mut chunks = Vec::new();
    let mut start = 0;
    loop {
        let end = (start + chunk_size).min(words.len());
        chunks.push(words[start..end].join(" "));
        if end == words.len() {
            break;
        }
        start += step;
    }
    chunks
}
/// Rough token estimate: 1.3 tokens per whitespace-separated word, truncated
/// towards zero.
fn estimate_tokens(text: &str) -> i32 {
    let word_count = text.split_whitespace().count();
    let estimate = word_count as f32 * 1.3;
    estimate as i32
}
/// Escapes the characters that are significant in HTML text and attribute
/// values: `&`, `<`, `>`, `"` and `'`.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}