fix: GLM client - add chat_template_kwargs, handle reasoning_content, increase max_tokens to 16384
All checks were successful
BotServer CI/CD / build (push) Successful in 5m52s
parent 8a65afbfc5
commit 87df733db0
1 changed file with 92 additions and 74 deletions
src/llm/glm.rs (164)
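
For context on the chat_template_kwargs change: once populated, the serialized request body carries one extra object next to the usual OpenAI-style fields. A minimal sketch of that JSON shape using serde_json (the model name and message below are placeholders, not values from this commit):

// Sketch of the serialized request body after this change. Only the shape
// matters; the model name and message content are placeholder values.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "z-ai/glm4.7",
        "messages": [{ "role": "user", "content": "hello" }],
        "stream": true,
        "max_tokens": 16384,
        "temperature": 1.0,
        "top_p": 1.0,
        "chat_template_kwargs": {
            "enable_thinking": true,
            "clear_thinking": false
        }
    });
    // tools / tool_choice are omitted when None thanks to skip_serializing_if.
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}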
@@ -1,16 +1,12 @@
 use async_trait::async_trait;
 use futures::StreamExt;
-use log::{error, info};
+use log::{error, info, trace};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tokio::sync::mpsc;
 
 use super::LLMProvider;
 
-// GLM / z.ai API Client
-// Similar to OpenAI but with different endpoint structure
-// For z.ai, base URL already contains version (e.g., /v4), endpoint is just /chat/completions
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct GLMMessage {
     pub role: String,
@@ -20,6 +16,12 @@ pub struct GLMMessage {
     pub tool_calls: Option<Vec<Value>>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GLMChatTemplateKwargs {
+    pub enable_thinking: bool,
+    pub clear_thinking: bool,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct GLMRequest {
     pub model: String,
@@ -36,6 +38,8 @@ pub struct GLMRequest {
     pub tools: Option<Vec<Value>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_choice: Option<Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub chat_template_kwargs: Option<GLMChatTemplateKwargs>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -58,7 +62,6 @@ pub struct GLMResponse {
     pub usage: Option<Value>,
 }
 
-// Streaming structures
 #[derive(Debug, Clone, Default, Serialize, Deserialize)]
 pub struct GLMStreamDelta {
     #[serde(default)]
@@ -75,7 +78,6 @@ pub struct GLMStreamDelta {
 pub struct GLMStreamChoice {
     #[serde(default)]
     pub index: u32,
-    #[serde(default)]
     pub delta: GLMStreamDelta,
     #[serde(default)]
     pub finish_reason: Option<String>,
@@ -116,7 +118,6 @@ impl GLMClient {
         }
     }
 
-    /// Sanitizes a string by removing invalid UTF-8 surrogate characters
     fn sanitize_utf8(input: &str) -> String {
         input.chars()
             .filter(|c| {
@@ -142,26 +143,29 @@ impl LLMProvider for GLMClient {
             tool_calls: None,
         }];
 
-        // NVIDIA API uses z-ai/glm4.7 as the model identifier
         let model_name = if model == "glm-4" || model == "glm-4.7" {
             "z-ai/glm4.7"
         } else {
             model
         };
 
         let request = GLMRequest {
             model: model_name.to_string(),
             messages,
             stream: Some(false),
-            max_tokens: None,
+            max_tokens: Some(16384),
             temperature: Some(1.0),
             top_p: Some(1.0),
             tools: None,
             tool_choice: None,
+            chat_template_kwargs: Some(GLMChatTemplateKwargs {
+                enable_thinking: true,
+                clear_thinking: false,
+            }),
         };
 
         let url = self.build_url();
-        info!("GLM non-streaming request to: {}", url);
+        info!("[GLM] Non-streaming request to: {}", url);
 
         let response = self
             .client
@@ -174,7 +178,7 @@ impl LLMProvider for GLMClient {
 
         if !response.status().is_success() {
             let error_text = response.text().await.unwrap_or_default();
-            error!("GLM API error: {}", error_text);
+            error!("[GLM] API error: {}", error_text);
             return Err(format!("GLM API error: {}", error_text).into());
         }
 
@@ -197,15 +201,12 @@ impl LLMProvider for GLMClient {
         key: &str,
         tools: Option<&Vec<Value>>,
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        // config IS the messages array directly, not nested
         let messages = if let Some(msgs) = config.as_array() {
-            // Convert messages from config format to GLM format
             msgs.iter()
                 .filter_map(|m| {
                     let role = m.get("role")?.as_str()?;
                     let content = m.get("content")?.as_str()?;
                     let sanitized = Self::sanitize_utf8(content);
-                    // NVIDIA API accepts empty content, don't filter them out
                     Some(GLMMessage {
                         role: role.to_string(),
                         content: Some(sanitized),
@@ -214,7 +215,6 @@ impl LLMProvider for GLMClient {
                 })
                 .collect::<Vec<_>>()
         } else {
-            // Fallback to building from prompt
             vec![GLMMessage {
                 role: "user".to_string(),
                 content: Some(Self::sanitize_utf8(prompt)),
@@ -222,39 +222,39 @@ impl LLMProvider for GLMClient {
             }]
         };
 
-        // If no messages, return error
         if messages.is_empty() {
            return Err("No valid messages in request".into());
         }
 
-        // NVIDIA API uses z-ai/glm4.7 as the model identifier
-        // GLM-4.7 supports standard OpenAI-compatible function calling
         let model_name = if model == "glm-4" || model == "glm-4.7" {
             "z-ai/glm4.7"
         } else {
             model
         };
 
-        // Set tool_choice to "auto" when tools are present - this tells GLM to automatically decide when to call a tool
         let tool_choice = if tools.is_some() {
             Some(serde_json::json!("auto"))
         } else {
             None
         };
 
         let request = GLMRequest {
             model: model_name.to_string(),
             messages,
             stream: Some(true),
-            max_tokens: None,
+            max_tokens: Some(16384),
             temperature: Some(1.0),
             top_p: Some(1.0),
             tools: tools.cloned(),
             tool_choice,
+            chat_template_kwargs: Some(GLMChatTemplateKwargs {
+                enable_thinking: true,
+                clear_thinking: false,
+            }),
         };
 
         let url = self.build_url();
-        info!("GLM streaming request to: {}", url);
+        info!("[GLM] Streaming request to: {}", url);
 
         let response = self
             .client
@@ -267,29 +267,28 @@ impl LLMProvider for GLMClient {
 
         if !response.status().is_success() {
             let error_text = response.text().await.unwrap_or_default();
-            error!("GLM streaming error: {}", error_text);
+            error!("[GLM] Streaming error: {}", error_text);
             return Err(format!("GLM streaming error: {}", error_text).into());
         }
 
         let mut stream = response.bytes_stream();
+        let mut in_reasoning = false;
+        let mut has_sent_thinking = false;
         let mut buffer = Vec::new();
 
         while let Some(chunk_result) = stream.next().await {
             let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?;
 
             buffer.extend_from_slice(&chunk);
             let data = String::from_utf8_lossy(&buffer);
 
-            // Process SSE lines
             for line in data.lines() {
                 let line = line.trim();
 
                 if line.is_empty() {
                     continue;
                 }
 
                 if line == "data: [DONE]" {
-                    std::mem::drop(tx.send(String::new())); // Signal end
+                    let _ = tx.send(String::new()).await;
                     return Ok(());
                 }
 
@@ -299,43 +298,64 @@ impl LLMProvider for GLMClient {
                 if let Some(choices) = chunk_data.get("choices").and_then(|c| c.as_array()) {
                     for choice in choices {
                         if let Some(delta) = choice.get("delta") {
-                            // Handle tool_calls (GLM-4.7 standard function calling)
+                            // Handle tool_calls
                             if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
                                 for tool_call in tool_calls {
-                                    // Send tool_calls as JSON for the calling code to process
                                     let tool_call_json = serde_json::json!({
                                         "type": "tool_call",
                                         "content": tool_call
                                     }).to_string();
-                                    match tx.send(tool_call_json).await {
-                                        Ok(_) => {},
-                                        Err(e) => {
-                                            error!("Failed to send tool_call to channel: {}", e);
-                                        }
-                                    }
+                                    let _ = tx.send(tool_call_json).await;
                                 }
                             }
 
-                            // GLM-4.7 on NVIDIA sends thinking text via reasoning_content
-                            // The actual user-facing response is in content field
-                            // We ONLY send content — never reasoning_content (internal thinking)
-                            if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
-                                if !content.is_empty() {
-                                    match tx.send(content.to_string()).await {
-                                        Ok(_) => {},
-                                        Err(e) => {
-                                            error!("Failed to send to channel: {}", e);
-                                        }
-                                    }
-                                }
-                            }
-                        } else {
-                            // No delta in choice
+                            // Handle reasoning_content (thinking phase)
+                            let reasoning = delta.get("reasoning_content")
+                                .and_then(|r| r.as_str())
+                                .or_else(|| delta.get("reasoning").and_then(|r| r.as_str()));
+
+                            let content = delta.get("content").and_then(|c| c.as_str());
+
+                            // Enter reasoning mode
+                            if reasoning.is_some() && content.is_none() {
+                                if !in_reasoning {
+                                    trace!("[GLM] Entering reasoning/thinking mode");
+                                    in_reasoning = true;
+                                }
+                                if !has_sent_thinking {
+                                    let thinking = serde_json::json!({
+                                        "type": "thinking",
+                                        "content": "\u{1f914} Pensando..."
+                                    }).to_string();
+                                    let _ = tx.send(thinking).await;
+                                    has_sent_thinking = true;
+                                }
+                                continue;
+                            }
+
+                            // Exited reasoning — content is now real response
+                            if in_reasoning && content.is_some() {
+                                trace!("[GLM] Exited reasoning mode");
+                                in_reasoning = false;
+                                let clear = serde_json::json!({
+                                    "type": "thinking_clear",
+                                    "content": ""
+                                }).to_string();
+                                let _ = tx.send(clear).await;
+                            }
+
+                            // Send actual content to user
+                            if let Some(text) = content {
+                                if !text.is_empty() {
+                                    let _ = tx.send(text.to_string()).await;
+                                }
+                            }
                         }
 
                         if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
                             if !reason.is_empty() {
-                                info!("GLM stream finished: {}", reason);
-                                std::mem::drop(tx.send(String::new()));
+                                info!("[GLM] Stream finished: {}", reason);
+                                let _ = tx.send(String::new()).await;
                                 return Ok(());
                             }
                         }
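
As a standalone illustration of the delta handling in the hunk above: each SSE data: line is parsed as JSON, and reasoning_content (or reasoning) is distinguished from content. The sketch below reuses only the field names that appear in the diff; the sample chunk itself is invented for illustration:

// Standalone sketch of the delta parsing done in the hunk above.
// The sample SSE line is invented; only the field names (choices, delta,
// reasoning_content, reasoning, content) come from the diff.
use serde_json::Value;

fn main() {
    let line = r#"data: {"choices":[{"index":0,"delta":{"reasoning_content":"thinking..."},"finish_reason":null}]}"#;

    if let Some(payload) = line.strip_prefix("data: ") {
        let chunk: Value = serde_json::from_str(payload).expect("valid JSON chunk");
        if let Some(choices) = chunk.get("choices").and_then(|c| c.as_array()) {
            for choice in choices {
                if let Some(delta) = choice.get("delta") {
                    let reasoning = delta.get("reasoning_content")
                        .and_then(|r| r.as_str())
                        .or_else(|| delta.get("reasoning").and_then(|r| r.as_str()));
                    let content = delta.get("content").and_then(|c| c.as_str());
                    // Mirrors the new branch logic: reasoning without content
                    // means the model is still in its thinking phase.
                    if reasoning.is_some() && content.is_none() {
                        println!("thinking: {}", reasoning.unwrap());
                    } else if let Some(text) = content {
                        println!("content: {}", text);
                    }
                }
            }
        }
    }
}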
@@ -345,13 +365,12 @@ impl LLMProvider for GLMClient {
                 }
             }
 
-            // Keep unprocessed data in buffer
             if let Some(last_newline) = data.rfind('\n') {
                 buffer = buffer[last_newline + 1..].to_vec();
             }
         }
 
-        std::mem::drop(tx.send(String::new())); // Signal completion
+        let _ = tx.send(String::new()).await;
         Ok(())
     }
 
@@ -359,8 +378,7 @@ impl LLMProvider for GLMClient {
         &self,
         _session_id: &str,
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        // GLM doesn't have job cancellation
-        info!("GLM cancel requested for session {} (no-op)", _session_id);
+        info!("[GLM] Cancel requested for session {} (no-op)", _session_id);
         Ok(())
     }
 }
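
Downstream of this change, the mpsc channel carries JSON control frames of type "thinking", "thinking_clear", and "tool_call" alongside plain text chunks, with an empty string as the end-of-stream signal. A minimal consumer sketch, assuming the same tokio and serde_json crates the client uses; the receiver wiring and printouts here are illustrative, not part of this commit:

// Minimal consumer sketch (not part of the commit): shows how code reading
// from the mpsc channel might tell the control frames apart from plain text.
// The frame names mirror the strings emitted by the streaming code above;
// drain_stream and the printouts are illustrative only.
use serde_json::Value;
use tokio::sync::mpsc;

async fn drain_stream(mut rx: mpsc::Receiver<String>) {
    while let Some(msg) = rx.recv().await {
        // An empty string is the end-of-stream signal sent by the client.
        if msg.is_empty() {
            break;
        }
        // Control frames are JSON objects with a "type" field; ordinary text
        // chunks are forwarded verbatim and usually will not parse as JSON.
        match serde_json::from_str::<Value>(&msg) {
            Ok(v) if v.get("type").and_then(|t| t.as_str()) == Some("thinking") => {
                println!("model is thinking: {}", v["content"]);
            }
            Ok(v) if v.get("type").and_then(|t| t.as_str()) == Some("thinking_clear") => {
                println!("thinking finished, clearing placeholder");
            }
            Ok(v) if v.get("type").and_then(|t| t.as_str()) == Some("tool_call") => {
                println!("tool call requested: {}", v["content"]);
            }
            _ => print!("{}", msg), // ordinary content chunk
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<String>(32);
    // Simulate what the GLM client would send over the channel.
    let _ = tx.send(r#"{"type":"thinking","content":"..."}"#.to_string()).await;
    let _ = tx.send(r#"{"type":"thinking_clear","content":""}"#.to_string()).await;
    let _ = tx.send("Hello from GLM".to_string()).await;
    let _ = tx.send(String::new()).await;
    drain_stream(rx).await;
}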