323 lines
11 KiB
Rust
323 lines
11 KiB
Rust
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::multimodal::BotModelsClient;
|
|
use crate::shared::models::UserSession;
|
|
use crate::shared::state::AppState;
|
|
use log::{error, trace};
|
|
use rhai::{Dynamic, Engine};
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
|
|
/// Registers every multimodal script keyword (`IMAGE`, `VIDEO`, `AUDIO`, `SEE`) on `engine`.
///
/// Each registration receives its own handle to `state` and its own copy of
/// `user`; the final registration takes ownership of the originals instead of
/// cloning them one more time.
pub fn register_multimodal_keywords(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
    image_keyword(Arc::clone(&state), user.clone(), engine);
    video_keyword(Arc::clone(&state), user.clone(), engine);
    audio_keyword(Arc::clone(&state), user.clone(), engine);
    // Last consumer: hand over the owned values rather than cloning them again.
    see_keyword(state, user, engine);
}
|
|
|
|
|
|
|
|
pub fn image_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(&["IMAGE", "$expr$"], false, move |context, inputs| {
|
|
let prompt = context.eval_expression_tree(&inputs[0])?.to_string();
|
|
|
|
trace!("IMAGE keyword: generating image for prompt: {}", prompt);
|
|
|
|
let state_for_thread = Arc::clone(&state_clone);
|
|
let bot_id = user_clone.bot_id;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
|
.worker_threads(2)
|
|
.enable_all()
|
|
.build();
|
|
|
|
let send_err = if let Ok(rt) = rt {
|
|
let result = rt.block_on(async move {
|
|
execute_image_generation(state_for_thread, bot_id, prompt).await
|
|
});
|
|
tx.send(result).err()
|
|
} else {
|
|
tx.send(Err("Failed to build tokio runtime".into())).err()
|
|
};
|
|
|
|
if send_err.is_some() {
|
|
error!("Failed to send IMAGE result");
|
|
}
|
|
});
|
|
|
|
match rx.recv_timeout(Duration::from_secs(300)) {
|
|
Ok(Ok(result)) => Ok(Dynamic::from(result)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.to_string().into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
|
Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"Image generation timed out".into(),
|
|
rhai::Position::NONE,
|
|
)))
|
|
}
|
|
Err(e) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
format!("IMAGE thread failed: {}", e).into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
})
|
|
.unwrap();
|
|
}
|
|
|
|
async fn execute_image_generation(
|
|
state: Arc<AppState>,
|
|
bot_id: uuid::Uuid,
|
|
prompt: String,
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
let client = BotModelsClient::from_state(&state, &bot_id);
|
|
|
|
if !client.is_enabled() {
|
|
return Err("BotModels is not enabled. Set botmodels-enabled=true in config.csv".into());
|
|
}
|
|
|
|
client.generate_image(&prompt).await
|
|
}
|
|
|
|
|
|
|
|
pub fn video_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(&["VIDEO", "$expr$"], false, move |context, inputs| {
|
|
let prompt = context.eval_expression_tree(&inputs[0])?.to_string();
|
|
|
|
trace!("VIDEO keyword: generating video for prompt: {}", prompt);
|
|
|
|
let state_for_thread = Arc::clone(&state_clone);
|
|
let bot_id = user_clone.bot_id;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
|
.worker_threads(2)
|
|
.enable_all()
|
|
.build();
|
|
|
|
let send_err = if let Ok(rt) = rt {
|
|
let result = rt.block_on(async move {
|
|
execute_video_generation(state_for_thread, bot_id, prompt).await
|
|
});
|
|
tx.send(result).err()
|
|
} else {
|
|
tx.send(Err("Failed to build tokio runtime".into())).err()
|
|
};
|
|
|
|
if send_err.is_some() {
|
|
error!("Failed to send VIDEO result");
|
|
}
|
|
});
|
|
|
|
|
|
match rx.recv_timeout(Duration::from_secs(600)) {
|
|
Ok(Ok(result)) => Ok(Dynamic::from(result)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.to_string().into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
|
Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"Video generation timed out".into(),
|
|
rhai::Position::NONE,
|
|
)))
|
|
}
|
|
Err(e) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
format!("VIDEO thread failed: {}", e).into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
})
|
|
.unwrap();
|
|
}
|
|
|
|
async fn execute_video_generation(
|
|
state: Arc<AppState>,
|
|
bot_id: uuid::Uuid,
|
|
prompt: String,
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
let client = BotModelsClient::from_state(&state, &bot_id);
|
|
|
|
if !client.is_enabled() {
|
|
return Err("BotModels is not enabled. Set botmodels-enabled=true in config.csv".into());
|
|
}
|
|
|
|
client.generate_video(&prompt).await
|
|
}
|
|
|
|
|
|
|
|
pub fn audio_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(&["AUDIO", "$expr$"], false, move |context, inputs| {
|
|
let text = context.eval_expression_tree(&inputs[0])?.to_string();
|
|
|
|
trace!("AUDIO keyword: generating speech for text: {}", text);
|
|
|
|
let state_for_thread = Arc::clone(&state_clone);
|
|
let bot_id = user_clone.bot_id;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
|
.worker_threads(2)
|
|
.enable_all()
|
|
.build();
|
|
|
|
let send_err = if let Ok(rt) = rt {
|
|
let result = rt.block_on(async move {
|
|
execute_audio_generation(state_for_thread, bot_id, text).await
|
|
});
|
|
tx.send(result).err()
|
|
} else {
|
|
tx.send(Err("Failed to build tokio runtime".into())).err()
|
|
};
|
|
|
|
if send_err.is_some() {
|
|
error!("Failed to send AUDIO result");
|
|
}
|
|
});
|
|
|
|
match rx.recv_timeout(Duration::from_secs(120)) {
|
|
Ok(Ok(result)) => Ok(Dynamic::from(result)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.to_string().into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
|
Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"Audio generation timed out".into(),
|
|
rhai::Position::NONE,
|
|
)))
|
|
}
|
|
Err(e) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
format!("AUDIO thread failed: {}", e).into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
})
|
|
.unwrap();
|
|
}
|
|
|
|
async fn execute_audio_generation(
|
|
state: Arc<AppState>,
|
|
bot_id: uuid::Uuid,
|
|
text: String,
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
let client = BotModelsClient::from_state(&state, &bot_id);
|
|
|
|
if !client.is_enabled() {
|
|
return Err("BotModels is not enabled. Set botmodels-enabled=true in config.csv".into());
|
|
}
|
|
|
|
client.generate_audio(&text, None, None).await
|
|
}
|
|
|
|
|
|
|
|
pub fn see_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
|
let state_clone = Arc::clone(&state);
|
|
let user_clone = user.clone();
|
|
|
|
engine
|
|
.register_custom_syntax(&["SEE", "$expr$"], false, move |context, inputs| {
|
|
let file_path = context.eval_expression_tree(&inputs[0])?.to_string();
|
|
|
|
trace!("SEE keyword: getting caption for file: {}", file_path);
|
|
|
|
let state_for_thread = Arc::clone(&state_clone);
|
|
let bot_id = user_clone.bot_id;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
std::thread::spawn(move || {
|
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
|
.worker_threads(2)
|
|
.enable_all()
|
|
.build();
|
|
|
|
let send_err = if let Ok(rt) = rt {
|
|
let result = rt.block_on(async move {
|
|
execute_see_caption(state_for_thread, bot_id, file_path).await
|
|
});
|
|
tx.send(result).err()
|
|
} else {
|
|
tx.send(Err("Failed to build tokio runtime".into())).err()
|
|
};
|
|
|
|
if send_err.is_some() {
|
|
error!("Failed to send SEE result");
|
|
}
|
|
});
|
|
|
|
match rx.recv_timeout(Duration::from_secs(60)) {
|
|
Ok(Ok(result)) => Ok(Dynamic::from(result)),
|
|
Ok(Err(e)) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
e.to_string().into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
|
Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
"Vision/caption timed out".into(),
|
|
rhai::Position::NONE,
|
|
)))
|
|
}
|
|
Err(e) => Err(Box::new(rhai::EvalAltResult::ErrorRuntime(
|
|
format!("SEE thread failed: {}", e).into(),
|
|
rhai::Position::NONE,
|
|
))),
|
|
}
|
|
})
|
|
.unwrap();
|
|
}
|
|
|
|
async fn execute_see_caption(
|
|
state: Arc<AppState>,
|
|
bot_id: uuid::Uuid,
|
|
file_path: String,
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
let client = BotModelsClient::from_state(&state, &bot_id);
|
|
|
|
if !client.is_enabled() {
|
|
return Err("BotModels is not enabled. Set botmodels-enabled=true in config.csv".into());
|
|
}
|
|
|
|
|
|
let lower_path = file_path.to_lowercase();
|
|
if lower_path.ends_with(".mp4")
|
|
|| lower_path.ends_with(".avi")
|
|
|| lower_path.ends_with(".mov")
|
|
|| lower_path.ends_with(".webm")
|
|
|| lower_path.ends_with(".mkv")
|
|
{
|
|
client.describe_video(&file_path).await
|
|
} else {
|
|
client.describe_image(&file_path).await
|
|
}
|
|
}
|