Add AzureGPT5 client and provider detection
Some checks failed
BotServer CI/CD / build (push) Has been cancelled
Some checks failed
BotServer CI/CD / build (push) Has been cancelled
- Add AzureGPT5Client struct for the Responses API
- Add AzureGPT5 to the LLMProviderType enum
- Detect the provider via `azuregpt5` or `gpt5` in the llm-provider config
- Fix the `chars.peek()` issue in gpt_oss_120b.rs
This commit is contained in:
parent
de418e8fa7
commit
c603618865
3 changed files with 284 additions and 101 deletions
|
|
@ -326,20 +326,31 @@ impl LLMProvider for KimiClient {
|
|||
}
|
||||
}
|
||||
|
||||
// Kimi K2.5: content has the answer, reasoning/reasoning_content is thinking
|
||||
if let Some(text) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
if !text.is_empty() {
|
||||
let processed = handler.process_content_streaming(text, &mut stream_state);
|
||||
if !processed.is_empty() {
|
||||
total_content_chars += processed.len();
|
||||
if tx.send(processed).await.is_err() {
|
||||
info!("[Kimi] Channel closed, stopping stream after {} content chars", total_content_chars);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Kimi K2.5: content has the answer, reasoning/reasoning_content is thinking
|
||||
if let Some(text) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
if !text.is_empty() {
|
||||
let processed = handler.process_content_streaming(text, &mut stream_state);
|
||||
if !processed.is_empty() {
|
||||
total_content_chars += processed.len();
|
||||
if tx.send(processed).await.is_err() {
|
||||
info!("[Kimi] Channel closed, stopping stream after {} content chars", total_content_chars);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for content filter errors
|
||||
if let Some(filter_result) = delta.get("content_filter_result") {
|
||||
if let Some(error) = filter_result.get("error") {
|
||||
let code = error.get("code").and_then(|c| c.as_str()).unwrap_or("unknown");
|
||||
let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("no message");
|
||||
error!("[Kimi] Content filter error: code={}, message={}", code, message);
|
||||
} else {
|
||||
log::trace!("[Kimi] Content filter result (no error): {:?}", filter_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
|
||||
if !reason.is_empty() {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
use super::ModelHandler;
|
||||
use log;
|
||||
|
||||
/// Model handler for GPT-OSS 120B.
///
/// Stateless marker type: its `ModelHandler` implementation filters the
/// model's thinking markers (`<thinking>`/`**start**` … `</thinking>`/`**end**`)
/// out of responses.
#[derive(Debug)]
pub struct GptOss120bHandler {}
|
||||
|
||||
impl Default for GptOss120bHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
|
|
@ -15,70 +17,107 @@ impl GptOss120bHandler {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract content outside thinking tags
|
||||
/// If everything is inside thinking tags, extract from inside them
|
||||
fn strip_think_tags(content: &str) -> String {
|
||||
let result = content
|
||||
.replace("<think>", "")
|
||||
.replace("</think>", "")
|
||||
.replace("**start**", "")
|
||||
.replace("**end**", "");
|
||||
if result.is_empty() && !content.is_empty() {
|
||||
content.to_string()
|
||||
} else {
|
||||
result
|
||||
if content.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
let mut in_thinking = false;
|
||||
let mut thinking_content = String::new();
|
||||
let mut pos = 0;
|
||||
|
||||
while pos < content.len() {
|
||||
let remaining = &content[pos..];
|
||||
if !in_thinking {
|
||||
if remaining.starts_with("<thinking>") {
|
||||
in_thinking = true;
|
||||
thinking_content.clear();
|
||||
pos += 9;
|
||||
continue;
|
||||
} else if remaining.starts_with("**start**") {
|
||||
in_thinking = true;
|
||||
thinking_content.clear();
|
||||
pos += 8;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if remaining.starts_with("</thinking>") {
|
||||
in_thinking = false;
|
||||
pos += 11;
|
||||
continue;
|
||||
} else if remaining.starts_with("**end**") {
|
||||
in_thinking = false;
|
||||
pos += 6;
|
||||
continue;
|
||||
} else {
|
||||
thinking_content.push(content.chars().nth(pos).unwrap());
|
||||
pos += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if !in_thinking {
|
||||
result.push(content.chars().nth(pos).unwrap());
|
||||
}
|
||||
pos += 1;
|
||||
}
|
||||
|
||||
// If we got content outside thinking tags, return it
|
||||
if !result.trim().is_empty() {
|
||||
return result;
|
||||
}
|
||||
|
||||
// If everything was inside thinking tags, return that content
|
||||
if !thinking_content.trim().is_empty() {
|
||||
log::debug!("gpt_oss_120b: All content was in thinking tags, returning thinking content");
|
||||
return thinking_content;
|
||||
}
|
||||
|
||||
// Fallback: try regex extraction
|
||||
if let Ok(re) = regex::Regex::new(r"<thinking>(.*?)</thinking>") {
|
||||
let mut extracted = String::new();
|
||||
for cap in re.captures_iter(content) {
|
||||
if let Some(m) = cap.get(1) {
|
||||
if !extracted.is_empty() {
|
||||
extracted.push(' ');
|
||||
}
|
||||
extracted.push_str(m.as_str());
|
||||
}
|
||||
}
|
||||
if !extracted.is_empty() {
|
||||
return extracted;
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: return original
|
||||
log::warn!("gpt_oss_120b: Could not extract meaningful content, returning original");
|
||||
content.to_string()
|
||||
}
|
||||
|
||||
impl ModelHandler for GptOss120bHandler {
|
||||
fn is_analysis_complete(&self, buffer: &str) -> bool {
|
||||
buffer.contains("**end**") || buffer.contains("</think>")
|
||||
buffer.contains("**end**") || buffer.contains("</thinking>")
|
||||
}
|
||||
|
||||
fn process_content(&self, content: &str) -> String {
|
||||
strip_think_tags(content)
|
||||
}
|
||||
|
||||
fn process_content_streaming(&self, chunk: &str, state: &mut String) -> String {
|
||||
let old_len = state.len();
|
||||
state.push_str(chunk);
|
||||
|
||||
let mut clean_current = String::new();
|
||||
let mut in_think = false;
|
||||
// Process accumulated state and return new content since last call
|
||||
let processed = strip_think_tags(state);
|
||||
|
||||
let full_text = state.as_str();
|
||||
let mut current_pos = 0;
|
||||
|
||||
while current_pos < full_text.len() {
|
||||
if !in_think {
|
||||
if full_text[current_pos..].starts_with("<think>") {
|
||||
in_think = true;
|
||||
current_pos += 7;
|
||||
} else if full_text[current_pos..].starts_with("**start**") {
|
||||
current_pos += 10;
|
||||
} else if full_text[current_pos..].starts_with("**end**") {
|
||||
current_pos += 7;
|
||||
} else {
|
||||
let c = full_text[current_pos..].chars().next().unwrap();
|
||||
if current_pos >= old_len {
|
||||
clean_current.push(c);
|
||||
}
|
||||
current_pos += c.len_utf8();
|
||||
}
|
||||
} else {
|
||||
if full_text[current_pos..].starts_with("</think>") {
|
||||
in_think = false;
|
||||
current_pos += 8;
|
||||
} else {
|
||||
let c = full_text[current_pos..].chars().next().unwrap();
|
||||
current_pos += c.len_utf8();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clean_current.is_empty() && chunk.len() > 0 {
|
||||
chunk.to_string()
|
||||
} else {
|
||||
clean_current
|
||||
}
|
||||
// For streaming, we return the entire processed content
|
||||
// The caller should handle deduplication if needed
|
||||
processed
|
||||
}
|
||||
|
||||
fn has_analysis_markers(&self, buffer: &str) -> bool {
|
||||
buffer.contains("**start**") || buffer.contains("<think>")
|
||||
buffer.contains("**start**") || buffer.contains("<thinking>")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
209
src/llm/mod.rs
209
src/llm/mod.rs
|
|
@ -58,6 +58,119 @@ pub struct OpenAIClient {
|
|||
rate_limiter: Arc<ApiRateLimiter>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AzureGPT5Client {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
api_version: String,
|
||||
rate_limiter: Arc<ApiRateLimiter>,
|
||||
}
|
||||
|
||||
impl AzureGPT5Client {
|
||||
pub fn new(base_url: String, api_version: Option<String>) -> Self {
|
||||
let api_version = api_version.unwrap_or_else(|| "2025-04-01-preview".to_string());
|
||||
let rate_limiter = Arc::new(ApiRateLimiter::unlimited());
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
api_version,
|
||||
rate_limiter,
|
||||
}
|
||||
}
|
||||
|
||||
fn sanitize_utf8(input: &str) -> String {
|
||||
input.chars()
|
||||
.filter(|c| {
|
||||
let cp = *c as u32;
|
||||
!(0xD800..=0xDBFF).contains(&cp) && !(0xDC00..=0xDFFF).contains(&cp)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for AzureGPT5Client {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let raw_messages = if config.is_array() && !config.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
config
|
||||
} else {
|
||||
&serde_json::json!([{"role": "user", "content": prompt}])
|
||||
};
|
||||
|
||||
let full_url = format!(
|
||||
"{}/openai/responses?api-version={}",
|
||||
self.base_url, self.api_version
|
||||
);
|
||||
let auth_header = format!("Bearer {}", key);
|
||||
|
||||
let input_array: Vec<Value> = raw_messages
|
||||
.as_array()
|
||||
.unwrap_or(&vec![])
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
serde_json::json!({
|
||||
"role": msg.get("role").and_then(|r| r.as_str()).unwrap_or("user"),
|
||||
"content": Self::sanitize_utf8(msg.get("content").and_then(|c| c.as_str()).unwrap_or(""))
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&full_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"input": input_array,
|
||||
"max_output_tokens": 16384
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("AzureGPT5 generate error: {}", error_text);
|
||||
return Err(format!("AzureGPT5 request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
let content = result["output"][0]["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
|
||||
Ok(content.to_string())
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
_tools: Option<&Vec<Value>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let content = self.generate(prompt, config, model, key).await?;
|
||||
tx.send(content).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Estimates token count for a text string (roughly 4 characters per token for English)
|
||||
fn estimate_tokens(text: &str) -> usize {
|
||||
|
|
@ -458,45 +571,56 @@ impl LLMProvider for OpenAIClient {
|
|||
last_bytes = chunk_str.chars().take(100).collect();
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if let Some(reasoning) = data["choices"][0]["delta"]["reasoning_content"].as_str() {
|
||||
if !reasoning.is_empty() {
|
||||
if !in_reasoning {
|
||||
in_reasoning = true;
|
||||
}
|
||||
let thinking_msg = serde_json::json!({
|
||||
"type": "thinking",
|
||||
"content": reasoning
|
||||
}).to_string();
|
||||
let _ = tx.send(thinking_msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
if !content.is_empty() {
|
||||
if in_reasoning {
|
||||
in_reasoning = false;
|
||||
let clear_msg = serde_json::json!({"type": "thinking_clear"}).to_string();
|
||||
let _ = tx.send(clear_msg).await;
|
||||
}
|
||||
let processed = handler.process_content_streaming(content, &mut stream_state);
|
||||
if !processed.is_empty() {
|
||||
content_sent += processed.len();
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = data["choices"][0]["delta"]["tool_calls"].as_array() {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(args) = func.get("arguments").and_then(|a| a.as_str()) {
|
||||
let _ = tx.send(args.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
// Check for content filter errors
|
||||
if let Some(filter_result) = data["choices"][0]["delta"]["content_filter_result"].as_object() {
|
||||
if let Some(error) = filter_result.get("error") {
|
||||
let code = error.get("code").and_then(|c| c.as_str()).unwrap_or("unknown");
|
||||
let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("no message");
|
||||
error!("LLM Content filter error: code={}, message={}", code, message);
|
||||
} else {
|
||||
trace!("LLM Content filter result (no error): {:?}", filter_result);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reasoning) = data["choices"][0]["delta"]["reasoning_content"].as_str() {
|
||||
if !reasoning.is_empty() {
|
||||
if !in_reasoning {
|
||||
in_reasoning = true;
|
||||
}
|
||||
let thinking_msg = serde_json::json!({
|
||||
"type": "thinking",
|
||||
"content": reasoning
|
||||
}).to_string();
|
||||
let _ = tx.send(thinking_msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
||||
if !content.is_empty() {
|
||||
if in_reasoning {
|
||||
in_reasoning = false;
|
||||
let clear_msg = serde_json::json!({"type": "thinking_clear"}).to_string();
|
||||
let _ = tx.send(clear_msg).await;
|
||||
}
|
||||
let processed = handler.process_content_streaming(content, &mut stream_state);
|
||||
if !processed.is_empty() {
|
||||
content_sent += processed.len();
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = data["choices"][0]["delta"]["tool_calls"].as_array() {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(args) = func.get("arguments").and_then(|a| a.as_str()) {
|
||||
let _ = tx.send(args.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -525,6 +649,7 @@ pub enum LLMProviderType {
|
|||
OpenAI,
|
||||
Claude,
|
||||
AzureClaude,
|
||||
AzureGPT5,
|
||||
GLM,
|
||||
Bedrock,
|
||||
Vertex,
|
||||
|
|
@ -539,6 +664,10 @@ impl From<&str> for LLMProviderType {
|
|||
} else {
|
||||
Self::Claude
|
||||
}
|
||||
} else if lower.contains("azuregpt5") || lower.contains("gpt5") {
|
||||
Self::AzureGPT5
|
||||
} else if lower.contains("openai.azure.com") && lower.contains("responses") {
|
||||
Self::AzureGPT5
|
||||
} else if lower.contains("z.ai") || lower.contains("glm") {
|
||||
Self::GLM
|
||||
} else if lower.contains("bedrock") {
|
||||
|
|
@ -578,6 +707,10 @@ pub fn create_llm_provider(
|
|||
);
|
||||
std::sync::Arc::new(ClaudeClient::azure(base_url, deployment))
|
||||
}
|
||||
LLMProviderType::AzureGPT5 => {
|
||||
info!("Creating Azure GPT-5/Responses LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(AzureGPT5Client::new(base_url, endpoint_path))
|
||||
}
|
||||
LLMProviderType::GLM => {
|
||||
info!("Creating GLM/z.ai LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(GLMClient::new(base_url))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue