generalbots/src/llm/bedrock.rs
Rodrigo Rodriguez (Pragmatismo) 260a13e77d
Some checks failed
BotServer CI / build (push) Has been cancelled
refactor: apply various fixes across botserver
2026-03-10 15:15:21 -03:00

280 lines
11 KiB
Rust

use async_trait::async_trait;
use futures::StreamExt;
use log::{error, info, warn};
use serde_json::Value;
use tokio::sync::mpsc;
use crate::llm::LLMProvider;
/// HTTP client for an Amazon Bedrock chat endpoint.
///
/// The endpoint URL is normalised once by `BedrockClient::new` and then reused
/// for every request issued through this value.
#[derive(Debug)]
pub struct BedrockClient {
    /// Reusable HTTP client that performs the POST requests.
    client: reqwest::Client,
    /// Endpoint URL requests are sent to (already resolved by `new`).
    base_url: String,
}
impl BedrockClient {
    /// Creates a client for the given Bedrock endpoint.
    ///
    /// Three URL shapes are accepted:
    /// 1. OpenAI-compatible: `.../openai/v1/chat/completions` (used as-is)
    /// 2. Native invoke: `.../model/{model-id}/invoke` (used as-is; streaming swaps the
    ///    trailing `/invoke` for `/invoke-with-response-stream`)
    /// 3. Bare domain: `https://bedrock-runtime.region.amazonaws.com` (the
    ///    OpenAI-compatible path is appended automatically)
    pub fn new(base_url: String) -> Self {
        // Any of these markers means the caller already supplied a full route.
        let already_routed = ["/openai/", "/chat/completions", "/model/"]
            .iter()
            .any(|marker| base_url.contains(*marker));

        let url = if already_routed {
            base_url
        } else {
            format!(
                "{}/openai/v1/chat/completions",
                base_url.trim_end_matches('/')
            )
        };

        Self {
            client: reqwest::Client::new(),
            base_url: url,
        }
    }

    /// Returns `true` when the URL is a native Bedrock invoke endpoint
    /// (as opposed to the OpenAI-compatible one).
    fn is_native_invoke(&self) -> bool {
        self.base_url.contains("/model/") && self.base_url.contains("/invoke")
    }

    /// Returns the URL to use for streaming requests.
    ///
    /// For a native invoke URL that ends in `/invoke`, only that trailing segment is
    /// swapped for `/invoke-with-response-stream`; every other URL is returned unchanged.
    fn stream_url(&self) -> String {
        if self.is_native_invoke() {
            // Rewrite the suffix only: a plain `replace` would also mangle any earlier
            // path segment that happens to start with "/invoke".
            if let Some(prefix) = self.base_url.strip_suffix("/invoke") {
                return format!("{}/invoke-with-response-stream", prefix);
            }
        }
        self.base_url.clone()
    }

    /// Builds the `Authorization` header value, adding the `Bearer ` scheme
    /// unless the key already carries it.
    fn auth_header(key: &str) -> String {
        if key.starts_with("Bearer ") {
            key.to_string()
        } else {
            format!("Bearer {}", key)
        }
    }

    /// Copies the incoming messages into a fresh JSON array.
    ///
    /// A non-array input yields an empty array.
    fn build_messages(raw_messages: &Value) -> Value {
        Value::Array(raw_messages.as_array().cloned().unwrap_or_default())
    }

    /// Interprets a single line of the streamed response.
    ///
    /// Only `data:` lines carrying JSON with a `choices[0].delta` object are considered;
    /// everything else (blank lines, `[DONE]`, unparsable payloads) is ignored.
    ///
    /// * Tool-call fragments from the first entry of `delta.tool_calls` are accumulated
    ///   into `tool_call_buffer`: a `name` restarts the buffer, `arguments` are appended.
    /// * Returns the non-empty `delta.content` text, if any, for the caller to forward.
    fn apply_sse_line(line: &str, tool_call_buffer: &mut String) -> Option<String> {
        // The space after "data:" is optional in the SSE wire format.
        let data = line.trim().strip_prefix("data:")?.trim_start();
        if data == "[DONE]" {
            return None;
        }

        let json: Value = serde_json::from_str(data).ok()?;
        let delta = json.get("choices")?.get(0)?.get("delta")?;

        let function = delta
            .get("tool_calls")
            .and_then(Value::as_array)
            .and_then(|calls| calls.first())
            .and_then(|call| call.get("function"));
        if let Some(function) = function {
            // NOTE(review): name and arguments are inserted verbatim (no JSON escaping).
            // The format is kept unchanged because its consumer lives outside this file.
            if let Some(name) = function.get("name").and_then(Value::as_str) {
                *tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name);
            }
            if let Some(args) = function.get("arguments").and_then(Value::as_str) {
                tool_call_buffer.push_str(args);
            }
        }

        delta
            .get("content")
            .and_then(Value::as_str)
            .filter(|text| !text.is_empty())
            .map(str::to_owned)
    }

    /// Sends a streaming chat request and forwards the response through `tx`.
    ///
    /// * `formatted_messages` - JSON array placed in the request's `messages` field.
    /// * `model` - value of the request's `model` field.
    /// * `key` - API key; the `Bearer ` prefix is added when missing.
    /// * `tools` - optional tool definitions; attached only when non-empty.
    /// * `tx` - receives each content fragment and, at the end, an accumulated
    ///   `` `tool_call`: {...} `` message if any tool-call fragments were seen.
    ///
    /// # Errors
    /// Returns an error when the request cannot be sent or the endpoint answers with a
    /// non-success status. A read error in the middle of the stream is logged and ends
    /// the stream early without failing the call. If the receiver behind `tx` is gone,
    /// the function returns `Ok(())` immediately.
    ///
    /// NOTE(review): the body is parsed as SSE `data:` lines with OpenAI-style deltas;
    /// whether the native `invoke-with-response-stream` endpoint uses that framing
    /// cannot be confirmed from this file.
    async fn do_stream(
        &self,
        formatted_messages: &Value,
        model: &str,
        key: &str,
        tools: Option<&Vec<Value>>,
        tx: &mpsc::Sender<String>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let auth_header = Self::auth_header(key);
        let mut request_body = serde_json::json!({
            "model": model,
            "messages": formatted_messages,
            "stream": true
        });
        if let Some(tools_value) = tools.filter(|list| !list.is_empty()) {
            request_body["tools"] = serde_json::json!(tools_value);
            info!("Added {} tools to Bedrock request", tools_value.len());
        }

        let url = self.stream_url();
        info!("Sending streaming request to Bedrock endpoint: {}", url);
        let response = self
            .client
            .post(&url)
            .header("Authorization", &auth_header)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            error!("Bedrock generate_stream error: {}", error_text);
            return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
        }

        let mut stream = response.bytes_stream();
        // Network chunks end at arbitrary byte offsets, so a line (or even a multi-byte
        // UTF-8 character) can be split across two chunks. Bytes are buffered here and
        // only complete lines are decoded and parsed.
        let mut pending: Vec<u8> = Vec::new();
        let mut tool_call_buffer = String::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(chunk) => chunk,
                Err(e) => {
                    error!("Bedrock stream reading error: {}", e);
                    break;
                }
            };
            pending.extend_from_slice(&chunk);

            while let Some(newline_pos) = pending.iter().position(|&byte| byte == b'\n') {
                let line_bytes: Vec<u8> = pending.drain(..=newline_pos).collect();
                let line = String::from_utf8_lossy(&line_bytes);
                if let Some(content) = Self::apply_sse_line(&line, &mut tool_call_buffer) {
                    if tx.send(content).await.is_err() {
                        // Receiver dropped: nobody is listening any more.
                        return Ok(());
                    }
                }
            }
        }

        // The last line may arrive without a terminating newline.
        if !pending.is_empty() {
            let line = String::from_utf8_lossy(&pending);
            if let Some(content) = Self::apply_sse_line(&line, &mut tool_call_buffer) {
                if tx.send(content).await.is_err() {
                    return Ok(());
                }
            }
        }

        if !tool_call_buffer.is_empty() {
            tool_call_buffer.push_str("\"}");
            let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await;
        }
        Ok(())
    }
}
#[async_trait]
impl LLMProvider for BedrockClient {
async fn generate(
&self,
prompt: &str,
messages: &Value,
model: &str,
key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
messages
} else {
&default_messages
};
let formatted_messages = Self::build_messages(raw_messages);
let auth_header = Self::auth_header(key);
let request_body = serde_json::json!({
"model": model,
"messages": formatted_messages,
"stream": false
});
info!("Sending request to Bedrock endpoint: {}", self.base_url);
let response = self.client
.post(&self.base_url)
.header("Authorization", &auth_header)
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
error!("Bedrock generate error: {}", error_text);
return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
}
let json: Value = response.json().await?;
if let Some(choices) = json.get("choices") {
if let Some(first_choice) = choices.get(0) {
if let Some(message) = first_choice.get("message") {
if let Some(content) = message.get("content") {
if let Some(content_str) = content.as_str() {
return Ok(content_str.to_string());
}
}
}
}
}
Err("Failed to parse response from Bedrock".into())
}
async fn generate_stream(
&self,
prompt: &str,
messages: &Value,
tx: mpsc::Sender<String>,
model: &str,
key: &str,
tools: Option<&Vec<Value>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
messages
} else {
&default_messages
};
let formatted_messages = Self::build_messages(raw_messages);
// Try with tools first
let result = self.do_stream(&formatted_messages, model, key, tools, &tx).await;
if let Err(ref e) = result {
let err_str = e.to_string();
// If error is "Operation not allowed" or validation_error, retry without tools
if (err_str.contains("Operation not allowed") || err_str.contains("validation_error"))
&& tools.is_some()
{
warn!(
"Bedrock model '{}' does not support tools, retrying without tools",
model
);
return self.do_stream(&formatted_messages, model, key, None, &tx).await;
}
}
result
}
async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}