feat: Cancel streaming LLM when user sends new message
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s
- Add active_streams HashMap to AppState to track streaming sessions - Create cancellation channel for each streaming session - Cancel existing streaming when new message arrives - Prevents overlapping responses and improves UX
This commit is contained in:
parent
01d4f47a93
commit
9db784fd5c
3 changed files with 79 additions and 45 deletions
|
|
@ -825,18 +825,29 @@ impl BotOrchestrator {
|
|||
// #[cfg(feature = "drive")]
|
||||
// set_llm_streaming(true);
|
||||
|
||||
let stream_tx_clone = stream_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
|
||||
.await
|
||||
{
|
||||
error!("LLM streaming error: {}", e);
|
||||
}
|
||||
// REMOVED: LLM streaming lock was causing deadlocks
|
||||
// #[cfg(feature = "drive")]
|
||||
// set_llm_streaming(false);
|
||||
});
|
||||
let stream_tx_clone = stream_tx.clone();
|
||||
|
||||
// Create cancellation channel for this streaming session
|
||||
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
|
||||
let session_id_str = session.id.to_string();
|
||||
|
||||
// Register this streaming session for potential cancellation
|
||||
{
|
||||
let mut active_streams = self.state.active_streams.lock().await;
|
||||
active_streams.insert(session_id_str.clone(), cancel_tx);
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
|
||||
.await
|
||||
{
|
||||
error!("LLM streaming error: {}", e);
|
||||
}
|
||||
// REMOVED: LLM streaming lock was causing deadlocks
|
||||
// #[cfg(feature = "drive")]
|
||||
// set_llm_streaming(false);
|
||||
});
|
||||
|
||||
let mut full_response = String::new();
|
||||
let mut analysis_buffer = String::new();
|
||||
|
|
@ -872,11 +883,17 @@ impl BotOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
chunk_count += 1;
|
||||
if chunk_count <= 3 || chunk_count % 50 == 0 {
|
||||
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
|
||||
}
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
// Check if cancellation was requested (user sent new message)
|
||||
if cancel_rx.try_recv().is_ok() {
|
||||
info!("Streaming cancelled for session {} - user sent new message", session.id);
|
||||
break;
|
||||
}
|
||||
|
||||
chunk_count += 1;
|
||||
if chunk_count <= 3 || chunk_count % 50 == 0 {
|
||||
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
|
||||
}
|
||||
|
||||
// ===== GENERIC TOOL EXECUTION =====
|
||||
// Add chunk to tool_call_buffer and try to parse
|
||||
|
|
@ -1718,25 +1735,37 @@ let mut send_task = tokio::spawn(async move {
|
|||
};
|
||||
|
||||
if let Some(tx_clone) = tx_opt {
|
||||
let corrected_msg = UserMessage {
|
||||
bot_id: bot_id.to_string(),
|
||||
user_id: session.user_id.to_string(),
|
||||
session_id: session.id.to_string(),
|
||||
..user_msg
|
||||
};
|
||||
info!("Calling orchestrator for session {}", session_id);
|
||||
// CANCEL any existing streaming for this session first
|
||||
let session_id_str = session_id.to_string();
|
||||
{
|
||||
let mut active_streams = state_clone.active_streams.lock().await;
|
||||
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
|
||||
info!("Cancelling existing streaming for session {}", session_id);
|
||||
let _ = cancel_tx.send(()).await;
|
||||
// Give a moment for the streaming to stop
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
let corrected_msg = UserMessage {
|
||||
bot_id: bot_id.to_string(),
|
||||
user_id: session.user_id.to_string(),
|
||||
session_id: session.id.to_string(),
|
||||
..user_msg
|
||||
};
|
||||
info!("Calling orchestrator for session {}", session_id);
|
||||
|
||||
// Spawn LLM in its own task so recv_task stays free to handle
|
||||
// new messages — prevents one hung LLM from locking the session.
|
||||
let orch = BotOrchestrator::new(state_clone.clone());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = orch
|
||||
.stream_response(corrected_msg, tx_clone)
|
||||
.await
|
||||
{
|
||||
error!("Failed to stream response: {}", e);
|
||||
}
|
||||
});
|
||||
// Spawn LLM in its own task so recv_task stays free to handle
|
||||
// new messages — prevents one hung LLM from locking the session.
|
||||
let orch = BotOrchestrator::new(state_clone.clone());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = orch
|
||||
.stream_response(corrected_msg, tx_clone)
|
||||
.await
|
||||
{
|
||||
error!("Failed to stream response: {}", e);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
warn!("Response channel NOT found for session: {}", session_id);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -398,6 +398,8 @@ pub struct AppState {
|
|||
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
|
||||
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
||||
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||
/// Active streaming sessions for cancellation: session_id → cancellation sender
|
||||
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
|
||||
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
|
||||
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
|
||||
pub web_adapter: Arc<WebChannelAdapter>,
|
||||
|
|
@ -450,6 +452,7 @@ impl Clone for AppState {
|
|||
kb_manager: self.kb_manager.clone(),
|
||||
channels: Arc::clone(&self.channels),
|
||||
response_channels: Arc::clone(&self.response_channels),
|
||||
active_streams: Arc::clone(&self.active_streams),
|
||||
hear_channels: Arc::clone(&self.hear_channels),
|
||||
web_adapter: Arc::clone(&self.web_adapter),
|
||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||
|
|
@ -665,6 +668,7 @@ impl Default for AppState {
|
|||
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
|
||||
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
||||
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||
voice_adapter: Arc::new(VoiceAdapter::new()),
|
||||
|
|
|
|||
|
|
@ -606,16 +606,17 @@ pub async fn create_app_state(
|
|||
dynamic_llm_provider: Some(dynamic_llm_provider.clone()),
|
||||
#[cfg(feature = "directory")]
|
||||
auth_service: auth_service.clone(),
|
||||
channels: Arc::new(tokio::sync::Mutex::new({
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"web".to_string(),
|
||||
web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>,
|
||||
);
|
||||
map
|
||||
})),
|
||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
||||
channels: Arc::new(tokio::sync::Mutex::new({
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"web".to_string(),
|
||||
web_adapter.clone() as Arc<dyn crate::core::bot::channels::ChannelAdapter>,
|
||||
);
|
||||
map
|
||||
})),
|
||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
||||
web_adapter: web_adapter.clone(),
|
||||
voice_adapter: voice_adapter.clone(),
|
||||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue