feat: Cancel streaming LLM when user sends new message
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s
All checks were successful
BotServer CI/CD / build (push) Successful in 6m4s
- Add active_streams HashMap to AppState to track streaming sessions
- Create a cancellation channel for each streaming session
- Cancel any existing stream for a session when a new message arrives
- Prevent overlapping responses and improve UX
This commit is contained in:
parent
01d4f47a93
commit
9db784fd5c
3 changed files with 79 additions and 45 deletions
|
|
@ -826,6 +826,17 @@ impl BotOrchestrator {
|
||||||
// set_llm_streaming(true);
|
// set_llm_streaming(true);
|
||||||
|
|
||||||
let stream_tx_clone = stream_tx.clone();
|
let stream_tx_clone = stream_tx.clone();
|
||||||
|
|
||||||
|
// Create cancellation channel for this streaming session
|
||||||
|
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
|
||||||
|
let session_id_str = session.id.to_string();
|
||||||
|
|
||||||
|
// Register this streaming session for potential cancellation
|
||||||
|
{
|
||||||
|
let mut active_streams = self.state.active_streams.lock().await;
|
||||||
|
active_streams.insert(session_id_str.clone(), cancel_tx);
|
||||||
|
}
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = llm
|
if let Err(e) = llm
|
||||||
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
|
.generate_stream("", &messages_clone, stream_tx_clone, &model_clone, &key_clone, tools_for_llm.as_ref())
|
||||||
|
|
@ -873,6 +884,12 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some(chunk) = stream_rx.recv().await {
|
while let Some(chunk) = stream_rx.recv().await {
|
||||||
|
// Check if cancellation was requested (user sent new message)
|
||||||
|
if cancel_rx.try_recv().is_ok() {
|
||||||
|
info!("Streaming cancelled for session {} - user sent new message", session.id);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
chunk_count += 1;
|
chunk_count += 1;
|
||||||
if chunk_count <= 3 || chunk_count % 50 == 0 {
|
if chunk_count <= 3 || chunk_count % 50 == 0 {
|
||||||
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
|
info!("LLM chunk #{chunk_count} received for session {} (len={})", session.id, chunk.len());
|
||||||
|
|
@ -1718,6 +1735,18 @@ let mut send_task = tokio::spawn(async move {
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(tx_clone) = tx_opt {
|
if let Some(tx_clone) = tx_opt {
|
||||||
|
// CANCEL any existing streaming for this session first
|
||||||
|
let session_id_str = session_id.to_string();
|
||||||
|
{
|
||||||
|
let mut active_streams = state_clone.active_streams.lock().await;
|
||||||
|
if let Some(cancel_tx) = active_streams.remove(&session_id_str) {
|
||||||
|
info!("Cancelling existing streaming for session {}", session_id);
|
||||||
|
let _ = cancel_tx.send(()).await;
|
||||||
|
// Give a moment for the streaming to stop
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let corrected_msg = UserMessage {
|
let corrected_msg = UserMessage {
|
||||||
bot_id: bot_id.to_string(),
|
bot_id: bot_id.to_string(),
|
||||||
user_id: session.user_id.to_string(),
|
user_id: session.user_id.to_string(),
|
||||||
|
|
|
||||||
|
|
@ -398,6 +398,8 @@ pub struct AppState {
|
||||||
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
|
pub auth_service: Arc<tokio::sync::Mutex<AuthService>>,
|
||||||
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
pub channels: Arc<tokio::sync::Mutex<HashMap<String, Arc<dyn ChannelAdapter>>>>,
|
||||||
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
pub response_channels: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||||
|
/// Active streaming sessions for cancellation: session_id → cancellation sender
|
||||||
|
pub active_streams: Arc<tokio::sync::Mutex<HashMap<String, mpsc::Sender<()>>>>,
|
||||||
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
|
/// Blocking channels for HEAR: session_id → sender. Rhai thread blocks on receiver.
|
||||||
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
|
pub hear_channels: Arc<std::sync::Mutex<HashMap<uuid::Uuid, std::sync::mpsc::SyncSender<String>>>>,
|
||||||
pub web_adapter: Arc<WebChannelAdapter>,
|
pub web_adapter: Arc<WebChannelAdapter>,
|
||||||
|
|
@ -450,6 +452,7 @@ impl Clone for AppState {
|
||||||
kb_manager: self.kb_manager.clone(),
|
kb_manager: self.kb_manager.clone(),
|
||||||
channels: Arc::clone(&self.channels),
|
channels: Arc::clone(&self.channels),
|
||||||
response_channels: Arc::clone(&self.response_channels),
|
response_channels: Arc::clone(&self.response_channels),
|
||||||
|
active_streams: Arc::clone(&self.active_streams),
|
||||||
hear_channels: Arc::clone(&self.hear_channels),
|
hear_channels: Arc::clone(&self.hear_channels),
|
||||||
web_adapter: Arc::clone(&self.web_adapter),
|
web_adapter: Arc::clone(&self.web_adapter),
|
||||||
voice_adapter: Arc::clone(&self.voice_adapter),
|
voice_adapter: Arc::clone(&self.voice_adapter),
|
||||||
|
|
@ -665,6 +668,7 @@ impl Default for AppState {
|
||||||
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
|
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
|
||||||
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
|
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
||||||
web_adapter: Arc::new(WebChannelAdapter::new()),
|
web_adapter: Arc::new(WebChannelAdapter::new()),
|
||||||
voice_adapter: Arc::new(VoiceAdapter::new()),
|
voice_adapter: Arc::new(VoiceAdapter::new()),
|
||||||
|
|
|
||||||
|
|
@ -615,6 +615,7 @@ pub async fn create_app_state(
|
||||||
map
|
map
|
||||||
})),
|
})),
|
||||||
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
|
active_streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||||
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
hear_channels: Arc::new(std::sync::Mutex::new(HashMap::new())),
|
||||||
web_adapter: web_adapter.clone(),
|
web_adapter: web_adapter.clone(),
|
||||||
voice_adapter: voice_adapter.clone(),
|
voice_adapter: voice_adapter.clone(),
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue