feat: Cancel streaming LLM when user sends new message
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s

- Add active_streams HashMap to AppState to track streaming sessions
- Create cancellation channel for each streaming session
- Cancel existing streaming when new message arrives
- Prevents overlapping responses and improves UX
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-04-15 07:37:07 -03:00
parent 01d4f47a93
commit 9db784fd5c
3 changed files with 79 additions and 45 deletions

View file

@ -825,18 +825,29 @@ impl BotOrchestrator {
// #[cfg(feature = "drive")] // #[cfg(feature = "drive")]
// set_llm_streaming(true); // set_llm_streaming(true);
let stream_tx_clone = stream_tx.clone(); let stream_tx_clone = stream_tx.clone();
tokio::spawn(async move {
if let Err(e) = llm // Create cancellation channel for this streaming session
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref()) let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
.await let session_id_str = session.id.to_string();
{
error!("LLM streaming error: {}", e); // Register this streaming session for potential cancellation
} {
// REMOVED: LLM streaming lock was causing deadlocks let mut active_streams = self.state.active_streams.lock().await;
// #[cfg(feature = "drive")] active_streams.insert(session_id_str.clone(), cancel_tx);
// set_llm_streaming(false); }
});
tokio::spawn(async move {
if let Err(e) = llm
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
.await
{
error!("LLM streaming error: {}", e);
}
// REMOVED: LLM streaming lock was causing deadlocks
// #[cfg(feature = "drive")]
// set_llm_streaming(false);
});
let mut full_response = String::new(); let mut full_response = String::new();
let mut analysis_buffer = String::new(); let mut analysis_buffer = String::new();
@ -872,11 +883,17 @@ impl BotOrchestrator {
} }
} }
while let Some(chunk) = stream_rx.recv().await { while let Some(chunk) = stream_rx.recv().await {
chunk_count += 1; // Check if cancellation was requested (user sent new message)
if chunk_count <= 3 || chunk_count % 50 == 0 { if cancel_rx.try_recv().is_ok() {
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len()); info!("Streaming cancelled for session {} - user sent new message", session.id);
} break;
}
chunk_count += 1;
if chunk_count <= 3 || chunk_count % 50 == 0 {
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
}
// ===== GENERIC TOOL EXECUTION ===== // ===== GENERIC TOOL EXECUTION =====
// Add chunk to tool_call_buffer and try to parse // Add chunk to tool_call_buffer and try to parse
@ -1718,25 +1735,37 @@ let mut send_task = tokio::spawn(async move {
}; };
if let Some(tx_clone) = tx_opt { if let Some(tx_clone) = tx_opt {
let corrected_msg = UserMessage { // CANCEL any existing streaming for this session first
bot_id: bot_id.to_string(), let session_id_str = session_id.to_string();
user_id: session.user_id.to_string(), {
session_id: session.id.to_string(), let mut active_streams = state_clone.active_streams.lock().await;
..user_msg if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
}; info!("Cancelling existing streaming for session {}", session_id);
info!("Calling orchestrator for session {}", session_id); let _ = cancel_tx.send(()).await;
// Give a moment for the streaming to stop
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
}
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),
user_id: session.user_id.to_string(),
session_id: session.id.to_string(),
..user_msg
};
info!("Calling orchestrator for session {}", session_id);
// Spawn LLM in its own task so recv_task stays free to handle // Spawn LLM in its own task so recv_task stays free to handle
// new messages — prevents one hung LLM from locking the session. // new messages — prevents one hung LLM from locking the session.
let orch = BotOrchestrator::new(state_clone.clone()); let orch = BotOrchestrator::new(state_clone.clone());
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = orch if let Err(e) = orch
.stream_response(corrected_msg, tx_clone) .stream_response(corrected_msg, tx_clone)
.await .await
{ {
error!("Failed to stream response: {}", e); error!("Failed to stream response: {}", e);
} }
}); });
} else { } else {
warn!("Response channel NOT found for session: {}", session_id); warn!("Response channel NOT found for session: {}", session_id);
} }

View file

@ -398,6 +398,8 @@ pub struct AppState {
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>, pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>, pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>, pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
/// Active streaming sessions for cancellation: session_id → cancellation sender
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver. /// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>, pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
pub web_adapter: Arc<WebChannelAdapter>, pub web_adapter: Arc<WebChannelAdapter>,
@ -450,6 +452,7 @@ impl Clone for AppState {
kb_manager: self.kb_manager.clone(), kb_manager: self.kb_manager.clone(),
channels: Arc::clone(&self.channels), channels: Arc::clone(&self.channels),
response_channels: Arc::clone(&self.response_channels), response_channels: Arc::clone(&self.response_channels),
active_streams: Arc::clone(&self.active_streams),
hear_channels: Arc::clone(&self.hear_channels), hear_channels: Arc::clone(&self.hear_channels),
web_adapter: Arc::clone(&self.web_adapter), web_adapter: Arc::clone(&self.web_adapter),
voice_adapter: Arc::clone(&self.voice_adapter), voice_adapter: Arc::clone(&self.voice_adapter),
@ -665,6 +668,7 @@ impl Default for AppState {
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())), auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
web_adapter: Arc::new(WebChannelAdapter::new()), web_adapter: Arc::new(WebChannelAdapter::new()),
voice_adapter: Arc::new(VoiceAdapter::new()), voice_adapter: Arc::new(VoiceAdapter::new()),

View file

@ -606,16 +606,17 @@ pub async fn create_app_state(
dynamic_llm_provider: Some(dynamic_llm_provider.clone()), dynamic_llm_provider: Some(dynamic_llm_provider.clone()),
#[cfg(feature = "directory")] #[cfg(feature = "directory")]
auth_service: auth_service.clone(), auth_service: auth_service.clone(),
channels: Arc::new(tokio::sync::Mutex::new({ channels: Arc::new(tokio::sync::Mutex::new({
let mut map = HashMap::new(); let mut map = HashMap::new();
map.insert( map.insert(
"web".to_string(), "web".to_string(),
web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>, web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>,
); );
map map
})), })),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())), active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
web_adapter: web_adapter.clone(), web_adapter: web_adapter.clone(),
voice_adapter: voice_adapter.clone(), voice_adapter: voice_adapter.clone(),
#[cfg(any(feature = "research", feature = "llm"))] #[cfg(any(feature = "research", feature = "llm"))]