diff --git a/src/basic/keywords/llm_keyword.rs b/src/basic/keywords/llm_keyword.rs
index 4204c4e9..7b11a5f5 100644
--- a/src/basic/keywords/llm_keyword.rs
+++ b/src/basic/keywords/llm_keyword.rs
@@ -18,40 +18,32 @@ pub fn llm_keyword(state: Arc, _user: UserSession, engine: &mut Engine
     let text = context
         .eval_expression_tree(first_input)?
         .to_string();
-    let state_for_thread = Arc::clone(&state_clone);
+    let state_for_async = Arc::clone(&state_clone);
     let prompt = build_llm_prompt(&text);
-    let (tx, rx) = std::sync::mpsc::channel();
-    std::thread::spawn(move || {
-        let rt = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(2)
-            .enable_all()
-            .build();
-        let send_err = if let Ok(rt) = rt {
-            let result = rt.block_on(async move {
-                execute_llm_generation(state_for_thread, prompt).await
-            });
-            tx.send(result).err()
-        } else {
-            tx.send(Err("failed to build tokio runtime".into())).err()
-        };
-        if send_err.is_some() {
-            error!("Failed to send LLM thread result");
-        }
+
+    // NOTE(review): Handle::block_on panics ("Cannot block the current thread
+    // from within a runtime") when invoked from an async execution context,
+    // and this keyword callback may well run on a runtime worker — the old
+    // code spawned a dedicated thread for exactly that reason. block_in_place
+    // moves this worker off the async scheduler before blocking; it requires
+    // the multi-thread runtime flavor, which the old code already assumed
+    // (Builder::new_multi_thread) — confirm that is the flavor in use.
+    let handle = tokio::runtime::Handle::current();
+    let result = tokio::task::block_in_place(|| {
+        handle.block_on(async move {
+            tokio::time::timeout(
+                Duration::from_secs(45),
+                execute_llm_generation(state_for_async, prompt),
+            )
+            .await
+        })
     });
-    match rx.recv_timeout(Duration::from_secs(500)) {
-        Ok(Ok(result)) => Ok(Dynamic::from(result)),
+
+    match result {
+        Ok(Ok(output)) => Ok(Dynamic::from(output)),
         Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
             e.to_string().into(),
             rhai::Position::NONE,
         ))),
-        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
-            Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
-                "LLM generation timed out".into(),
-                rhai::Position::NONE,
-            )))
-        }
-        Err(e) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
-            format!("LLM thread failed: {e}").into(),
+        Err(_) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
+            "LLM generation timed out after 45 seconds".into(),
             rhai::Position::NONE,
         ))),
     }