feat: add realtime audio service and speech endpoints

Rodrigo Rodriguez (Pragmatismo) 2026-03-14 16:35:41 -03:00
parent e37554087e
commit 09c59ddc6b
8 changed files with 339 additions and 108 deletions

Binary file not shown.

View file

@@ -10,9 +10,9 @@ python-json-logger==2.0.7
 # Generation Libraries
 diffusers==0.30.3
-torch==2.5.1
-torchaudio==2.5.1
-torchvision==0.20.1
+torch>=2.5.0
+torchaudio>=2.5.0
+torchvision>=0.20.0
 transformers==4.46.0
 accelerate==1.1.1
 safetensors==0.4.5
@@ -22,6 +22,7 @@ Pillow==11.0.0
 openai-whisper==20231117
 TTS==0.22.0
 scipy==1.14.1
+librosa==0.10.2.post1
 # Video Processing
 imageio==2.36.0
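Note: the torch stack moves from exact pins to `>=` ranges, so pip is free to resolve newer builds than the ones this commit was tested against. A minimal post-install sanity check (a sketch, assuming the service's virtualenv):

```python
# Quick sanity check for the relaxed pins; run inside the service's venv.
import librosa
import torch
import torchaudio
import torchvision

print(torch.__version__)        # must satisfy >=2.5.0
print(torchaudio.__version__)   # must satisfy >=2.5.0
print(torchvision.__version__)  # must satisfy >=0.20.0
print(librosa.__version__)      # pinned: 0.10.2.post1
```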

View file

@@ -1,3 +1,3 @@
-from . import image, scoring, speech, video, vision
+from . import image, scoring, speech, video
 
-__all__ = ["image", "video", "speech", "vision", "scoring"]
+__all__ = ["image", "video", "speech", "scoring"]

View file

@@ -6,6 +6,7 @@ from ....schemas.generation import (
     SpeechToTextResponse,
 )
 from ....services.speech_service import get_speech_service
+from ....services.realtime_audio_service import get_realtime_audio_service
 from ...dependencies import verify_api_key
 
 router = APIRouter(prefix="/speech", tags=["Speech"])
@@ -83,3 +84,25 @@ async def detect_language(
     audio_data = await file.read()
     result = await service.detect_language(audio_data)
     return result
+
+
+@router.post("/realtime")
+async def realtime_audio(
+    file: UploadFile = File(...),
+    api_key: str = Depends(verify_api_key),
+    service=Depends(get_realtime_audio_service),
+):
+    """
+    Process audio using real-time speech-to-speech models.
+
+    Args:
+        file: Audio file to process
+        api_key: API key for authentication
+        service: Real-time Audio service instance
+
+    Returns:
+        dict with transcribed text and execution info
+    """
+    audio_data = await file.read()
+    result = await service.process_audio(audio_data)
+    return result
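A minimal client sketch for the new route. The host/port, the `/api/v1` prefix, and the `X-API-Key` header name are assumptions; neither `verify_api_key` nor `api_v1_prefix` is shown in this diff.

```python
# Hedged usage sketch for POST .../speech/realtime.
# Base URL, API prefix, and auth header name are assumptions.
import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    with open("sample.wav", "rb") as f:
        resp = client.post(
            "/api/v1/speech/realtime",
            headers={"X-API-Key": "change-me"},
            files={"file": ("sample.wav", f, "audio/wav")},
        )
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "completed", "text": "...", "execution_time": ...}
```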

View file

@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Optional
 
 from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -20,6 +21,10 @@ class Settings(BaseSettings):
     version: str = "2.0.0"
     api_key: str = "change-me"
 
+    # External Providers for Speech (Optional)
+    groq_api_key: Optional[str] = None
+    openai_api_key: Optional[str] = None
+
     # Image generation model
     image_model_path: str = "./models/stable-diffusion-v1-5"
     image_steps: int = 4
@@ -46,6 +51,9 @@ class Settings(BaseSettings):
     # Whisper model for speech-to-text
     whisper_model_path: str = "./models/whisper"
 
+    # Real-time Audio model for speech-to-speech
+    realtime_audio_model_path: str = "./models/realtime_audio"
+
     # Device configuration
     device: str = "cuda"
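pydantic-settings resolves these fields case-insensitively from the environment or the configured `.env` file, so the new provider keys can be supplied without code changes. A sketch; the `app.core.config` module path is an assumption:

```python
# Sketch: the new optional keys are read from the environment by
# pydantic-settings. The import path below is assumed, not from the diff.
import os

os.environ["GROQ_API_KEY"] = "gsk_..."   # placeholder, not a real key
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, not a real key

from app.core.config import Settings

settings = Settings()
assert settings.groq_api_key and settings.openai_api_key
# With either key present, SpeechService.initialize() skips the local models.
```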

View file

@@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 
-from .api.v1.endpoints import image, scoring, speech, video, vision
+from .api.v1.endpoints import image, scoring, speech, video
 from .core.config import settings
 from .core.logging import get_logger
 from .services.image_service import get_image_service
@@ -50,7 +50,6 @@ app.add_middleware(
 app.include_router(image.router, prefix=settings.api_v1_prefix)
 app.include_router(video.router, prefix=settings.api_v1_prefix)
 app.include_router(speech.router, prefix=settings.api_v1_prefix)
-app.include_router(vision.router, prefix=settings.api_v1_prefix)
 app.include_router(scoring.router, prefix=settings.api_v1_prefix)
 
 app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")

View file

@@ -0,0 +1,113 @@
+import time
+import torch
+from pathlib import Path
+from typing import Optional, Dict, Any
+
+from ..core.config import settings
+from ..core.logging import get_logger
+
+logger = get_logger("realtime_audio_service")
+
+
+class RealtimeAudioService:
+    def __init__(self):
+        self.model = None
+        self.processor = None
+        self.device = settings.device
+        self._initialized = False
+
+    def initialize(self):
+        if self._initialized:
+            return
+
+        logger.info("Loading Real-time Audio model")
+        try:
+            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+            # Default to PersonaPlex but naming is generic
+            model_id = "nvidia/personaplex-7b-v1"
+            model_path = Path(settings.realtime_audio_model_path)
+
+            if model_path.exists():
+                load_path = str(model_path)
+            else:
+                load_path = model_id
+
+            logger.info(f"Loading model from {load_path}")
+            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+            self.processor = AutoProcessor.from_pretrained(load_path)
+            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                load_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_safetensors=True
+            ).to(self.device)
+
+            self._initialized = True
+            logger.info("Real-time Audio model loaded successfully")
+        except Exception as e:
+            logger.error("Failed to load Real-time Audio model", error=str(e))
+            self.model = None
+            self.processor = None
+
+    async def process_audio(self, audio_data: bytes, conversation_context: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Process audio input using real-time S2S model.
+        """
+        if not self._initialized:
+            self.initialize()
+
+        if self.model is None or self.processor is None:
+            return {
+                "status": "error",
+                "error": "Real-time Audio model not initialized"
+            }
+
+        start_time = time.time()
+        try:
+            import numpy as np
+            import librosa
+            import io
+
+            audio_buf = io.BytesIO(audio_data)
+            y, sr = librosa.load(audio_buf, sr=16000)
+
+            inputs = self.processor(y, sampling_rate=16000, return_tensors="pt").to(self.device)
+            inputs["input_features"] = inputs["input_features"].to(dtype=self.model.dtype)
+
+            with torch.no_grad():
+                generated_ids = self.model.generate(
+                    inputs["input_features"],
+                    max_new_tokens=256,
+                    do_sample=True,
+                    temperature=0.7
+                )
+
+            transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+            execution_time = time.time() - start_time
+            logger.info("Processed real-time audio", time=execution_time)
+
+            return {
+                "status": "completed",
+                "text": transcription.strip(),
+                "execution_time": execution_time,
+                "model": "realtime-audio-v1"
+            }
+        except Exception as e:
+            logger.error("Real-time audio processing failed", error=str(e))
+            return {
+                "status": "error",
+                "error": str(e)
+            }
+
+
+_service = None
+
+
+def get_realtime_audio_service():
+    global _service
+    if _service is None:
+        _service = RealtimeAudioService()
+    return _service
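A minimal async driver for the new service; the module path is assumed from the relative imports above, and the model only loads on the first call:

```python
# Sketch: exercising RealtimeAudioService via its singleton accessor.
# The import path is an assumption; adjust to the project layout.
import asyncio

from app.services.realtime_audio_service import get_realtime_audio_service


async def main() -> None:
    service = get_realtime_audio_service()  # lazy: initialize() runs on first use
    with open("question.wav", "rb") as f:
        result = await service.process_audio(f.read())
    print(result["status"], result.get("text") or result.get("error"))


asyncio.run(main())
```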

View file

@@ -1,10 +1,13 @@
 import io
 import tempfile
 import time
+import os
+import urllib.parse
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
+import httpx
 
 from ..core.config import settings
 from ..core.logging import get_logger
@@ -21,35 +24,48 @@ class SpeechService:
     def initialize(self):
         if self._initialized:
             return
 
-        logger.info("Loading speech models")
-        try:
-            # Load TTS model (Coqui TTS)
-            self._load_tts_model()
-
-            # Load Whisper model for speech-to-text
-            self._load_whisper_model()
-
-            self._initialized = True
-            logger.info("Speech models loaded successfully")
-        except Exception as e:
-            logger.error("Failed to load speech models", error=str(e))
-            # Don't raise - allow service to run with partial functionality
-            logger.warning("Speech service will have limited functionality")
+        # We only need local models if external providers are NOT configured
+        if not settings.groq_api_key and not settings.openai_api_key:
+            logger.info(
+                "External providers not configured, loading local speech models"
+            )
+            try:
+                # Load TTS model (Coqui TTS)
+                self._load_tts_model()
+
+                # Load Whisper model for speech-to-text
+                self._load_whisper_model()
+            except Exception as e:
+                logger.error("Failed to load local speech models", error=str(e))
+        else:
+            logger.info("External speech providers detected (Groq/OpenAI)")
+
+        self._initialized = True
 
     def _load_tts_model(self):
         """Load TTS model for text-to-speech generation"""
         try:
             from TTS.api import TTS
 
-            # Use a fast, high-quality model
+            # Use a lightweight model for low-RAM systems (4GB)
             self.tts_model = TTS(
-                model_name="tts_models/en/ljspeech/tacotron2-DDC",
+                model_name="tts_models/multilingual/multi-dataset/xtts_v2",
                 progress_bar=False,
                 gpu=(self.device == "cuda"),
             )
-            logger.info("TTS model loaded")
+            logger.info("Local TTS model loaded (xtts_v2)")
         except Exception as e:
-            logger.warning("TTS model not available", error=str(e))
-            self.tts_model = None
+            logger.warning("XTTS failed, trying lighter model", error=str(e))
+            try:
+                self.tts_model = TTS(
+                    model_name="tts_models/en/ljspeech/tacotron2-DDC",
+                    progress_bar=False,
+                    gpu=False,
+                )
+                logger.info("Local TTS model loaded (tacotron2)")
+            except Exception as e2:
+                logger.warning("Local TTS model not available", error=str(e2))
+                self.tts_model = None
 
     def _load_whisper_model(self):
@@ -65,9 +81,9 @@ class SpeechService:
                 )
             else:
                 self.whisper_model = whisper.load_model(model_size)
-            logger.info("Whisper model loaded", model=model_size)
+            logger.info("Local Whisper model loaded", model=model_size)
         except Exception as e:
-            logger.warning("Whisper model not available", error=str(e))
+            logger.warning("Local Whisper model not available", error=str(e))
             self.whisper_model = None
 
     async def generate(
@@ -85,40 +101,84 @@ class SpeechService:
         filename = f"{timestamp}_{hash(prompt) & 0xFFFFFF:06x}.wav"
         output_path = settings.output_dir / "audio" / filename
 
+        # Prefer OpenAI/Groq for high quality/speed if configured
+        if settings.openai_api_key:
+            logger.info("Generating speech via OpenAI API")
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(
+                        "https://api.openai.com/v1/audio/speech",
+                        headers={"Authorization": f"Bearer {settings.openai_api_key}"},
+                        json={
+                            "model": "tts-1",
+                            "input": prompt,
+                            "voice": voice or "alloy",
+                        },
+                        timeout=30.0,
+                    )
+                    response.raise_for_status()
+                    with open(output_path, "wb") as f:
+                        f.write(response.content)
+
+                generation_time = time.time() - start
+                return {
+                    "status": "completed",
+                    "file_path": f"/outputs/audio/{filename}",
+                    "generation_time": generation_time,
+                    "provider": "openai",
+                }
+            except Exception as e:
+                logger.error(
+                    "OpenAI speech generation failed, falling back", error=str(e)
+                )
+
+        # Fallback: Google Translate TTS (free, no API key needed)
+        try:
+            logger.info("Generating speech via Google Translate TTS")
+            lang = language or "pt-BR"
+            google_url = f"https://translate.google.com/translate_tts?ie=UTF-8&q={urllib.parse.quote(prompt)}&tl={lang}&client=tw-ob"
+            async with httpx.AsyncClient() as client:
+                response = await client.get(google_url, timeout=30.0)
+                response.raise_for_status()
+                with open(output_path, "wb") as f:
+                    f.write(response.content)
+
+            generation_time = time.time() - start
+            return {
+                "status": "completed",
+                "file_path": f"/outputs/audio/{filename}",
+                "generation_time": generation_time,
+                "provider": "google-translate",
+            }
+        except Exception as e:
+            logger.warning("Google Translate TTS failed", error=str(e))
+
         if self.tts_model is None:
-            logger.error("TTS model not available")
+            logger.error("No TTS provider available")
             return {
                 "status": "error",
-                "error": "TTS model not initialized",
+                "error": "No TTS provider initialized",
                 "file_path": None,
                 "generation_time": time.time() - start,
             }
 
         try:
-            logger.info(
-                "Generating speech",
-                text_length=len(prompt),
-                voice=voice,
-                language=language,
-            )
-
-            # Generate speech
+            logger.info("Generating speech via local model")
             self.tts_model.tts_to_file(
                 text=prompt,
                 file_path=str(output_path),
             )
 
             generation_time = time.time() - start
-            logger.info("Speech generated", file=filename, time=generation_time)
 
             return {
                 "status": "completed",
                 "file_path": f"/outputs/audio/{filename}",
                 "generation_time": generation_time,
+                "provider": "local",
             }
         except Exception as e:
-            logger.error("Speech generation failed", error=str(e))
+            logger.error("Local speech generation failed", error=str(e))
             return {
                 "status": "error",
                 "error": str(e),
@@ -127,97 +187,124 @@ class SpeechService:
             }
 
     async def to_text(self, audio_data: bytes) -> dict:
-        """Convert speech audio to text using Whisper"""
+        """Convert speech audio to text using Whisper (External or Local)"""
         if not self._initialized:
             self.initialize()
 
         start = time.time()
 
-        if self.whisper_model is None:
-            logger.error("Whisper model not available")
-            return {
-                "text": "",
-                "language": None,
-                "confidence": 0.0,
-                "error": "Whisper model not initialized",
-            }
-
-        try:
-            # Save audio to temporary file
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-                tmp.write(audio_data)
-                tmp_path = tmp.name
-
-            logger.info("Transcribing audio", file_size=len(audio_data))
-
-            # Transcribe
-            result = self.whisper_model.transcribe(tmp_path)
-
-            # Clean up temp file
-            import os
-            os.unlink(tmp_path)
-
-            transcription_time = time.time() - start
-            logger.info(
-                "Audio transcribed",
-                text_length=len(result["text"]),
-                language=result.get("language"),
-                time=transcription_time,
-            )
-
+        # 1. Try Groq (Ultra-fast Whisper)
+        if settings.groq_api_key:
+            logger.info("Transcribing via Groq Cloud")
+            try:
+                # Save to temp file for Groq API
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    tmp.write(audio_data)
+                    tmp_path = tmp.name
+
+                async with httpx.AsyncClient() as client:
+                    with open(tmp_path, "rb") as audio_file:
+                        files = {
+                            "file": (
+                                os.path.basename(tmp_path),
+                                audio_file,
+                                "audio/wav",
+                            )
+                        }
+                        data = {"model": "whisper-large-v3-turbo"}
+                        response = await client.post(
+                            "https://api.groq.com/openai/v1/audio/transcriptions",
+                            headers={
+                                "Authorization": f"Bearer {settings.groq_api_key}"
+                            },
+                            files=files,
+                            data=data,
+                            timeout=30.0,
+                        )
+                        response.raise_for_status()
+                        result = response.json()
+
+                os.unlink(tmp_path)
+                return {
+                    "text": result["text"].strip(),
+                    "language": result.get("language", "auto"),
+                    "confidence": 0.99,
+                    "provider": "groq",
+                }
+            except Exception as e:
+                logger.error("Groq transcription failed, falling back", error=str(e))
+
+        # 2. Try OpenAI
+        if settings.openai_api_key:
+            logger.info("Transcribing via OpenAI API")
+            try:
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    tmp.write(audio_data)
+                    tmp_path = tmp.name
+
+                async with httpx.AsyncClient() as client:
+                    with open(tmp_path, "rb") as audio_file:
+                        files = {
+                            "file": (
+                                os.path.basename(tmp_path),
+                                audio_file,
+                                "audio/wav",
+                            )
+                        }
+                        data = {"model": "whisper-1"}
+                        response = await client.post(
+                            "https://api.openai.com/v1/audio/transcriptions",
+                            headers={
+                                "Authorization": f"Bearer {settings.openai_api_key}"
+                            },
+                            files=files,
+                            data=data,
+                            timeout=30.0,
+                        )
+                        response.raise_for_status()
+                        result = response.json()
+
+                os.unlink(tmp_path)
+                return {
+                    "text": result["text"].strip(),
+                    "language": result.get("language", "auto"),
+                    "confidence": 0.99,
+                    "provider": "openai",
+                }
+            except Exception as e:
+                logger.error("OpenAI transcription failed, falling back", error=str(e))
+
+        # 3. Fallback to Local Whisper
+        if self.whisper_model is None:
+            return {"text": "", "error": "No STT provider available"}
+
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                tmp.write(audio_data)
+                tmp_path = tmp.name
+
+            logger.info("Transcribing via local model")
+            result = self.whisper_model.transcribe(tmp_path)
+            os.unlink(tmp_path)
+
             return {
                 "text": result["text"].strip(),
-                "language": result.get("language", "en"),
-                "confidence": 0.95,  # Whisper doesn't provide confidence directly
+                "language": result.get("language", "auto"),
+                "confidence": 0.95,
+                "provider": "local",
             }
         except Exception as e:
-            logger.error("Speech-to-text failed", error=str(e))
-            return {
-                "text": "",
-                "language": None,
-                "confidence": 0.0,
-                "error": str(e),
-            }
+            return {"text": "", "error": str(e)}
 
     async def detect_language(self, audio_data: bytes) -> dict:
-        """Detect the language of spoken audio"""
-        if not self._initialized:
-            self.initialize()
-
-        if self.whisper_model is None:
-            return {"language": None, "error": "Whisper model not initialized"}
-
-        try:
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-                tmp.write(audio_data)
-                tmp_path = tmp.name
-
-            import whisper
-
-            # Load audio and detect language
-            audio = whisper.load_audio(tmp_path)
-            audio = whisper.pad_or_trim(audio)
-            mel = whisper.log_mel_spectrogram(audio).to(self.whisper_model.device)
-            _, probs = self.whisper_model.detect_language(mel)
-
-            import os
-            os.unlink(tmp_path)
-
-            detected_lang = max(probs, key=probs.get)
-            confidence = probs[detected_lang]
-
-            return {
-                "language": detected_lang,
-                "confidence": confidence,
-            }
-        except Exception as e:
-            logger.error("Language detection failed", error=str(e))
-            return {"language": None, "error": str(e)}
+        """Detect language (simplified to reuse to_text if needed)"""
+        # Just use to_text and return the language field
+        result = await self.to_text(audio_data)
+        return {
+            "language": result.get("language"),
+            "confidence": result.get("confidence", 0.0),
+        }
 
 
 _service = None
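The new `to_text` walks a provider chain: Groq first, then OpenAI, then local Whisper, tagging each result with a `provider` field. A sketch of exercising it end to end (module path assumed):

```python
# Sketch: to_text() tries Groq, then OpenAI, then local Whisper.
# The import path is an assumption; adjust to the project layout.
import asyncio

from app.services.speech_service import get_speech_service


async def demo() -> None:
    service = get_speech_service()
    with open("clip.wav", "rb") as f:
        result = await service.to_text(f.read())
    # "provider" is "groq", "openai", or "local", depending on configured
    # keys and which call succeeded first.
    print(result.get("provider"), result.get("text") or result.get("error"))


asyncio.run(demo())
```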