feat: add realtime audio service and speech endpoints
parent e37554087e
commit 09c59ddc6b

8 changed files with 339 additions and 108 deletions
BIN  outputs/audio/20260311_232859_3d0378.wav  Normal file
Binary file not shown.
requirements.txt
@@ -10,9 +10,9 @@ python-json-logger==2.0.7
 
 # Generation Libraries
 diffusers==0.30.3
-torch==2.5.1
-torchaudio==2.5.1
-torchvision==0.20.1
+torch>=2.5.0
+torchaudio>=2.5.0
+torchvision>=0.20.0
 transformers==4.46.0
 accelerate==1.1.1
 safetensors==0.4.5
@@ -22,6 +22,7 @@ Pillow==11.0.0
 openai-whisper==20231117
 TTS==0.22.0
 scipy==1.14.1
+librosa==0.10.2.post1
 
 # Video Processing
 imageio==2.36.0
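
A quick way to confirm the relaxed pins resolved to compatible builds after `pip install -r requirements.txt`; a minimal sketch, assuming the `packaging` package is available in the environment:

    # Sanity-check the new version floors; not part of this commit.
    from importlib.metadata import version
    from packaging.version import Version

    assert Version(version("torch")) >= Version("2.5.0")
    assert Version(version("torchaudio")) >= Version("2.5.0")
    assert Version(version("torchvision")) >= Version("0.20.0")
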
src/api/v1/endpoints/__init__.py
@@ -1,3 +1,3 @@
-from . import image, scoring, speech, video, vision
+from . import image, scoring, speech, video
 
-__all__ = ["image", "video", "speech", "vision", "scoring"]
+__all__ = ["image", "video", "speech", "scoring"]
src/api/v1/endpoints/speech.py
@@ -6,6 +6,7 @@ from ....schemas.generation import (
     SpeechToTextResponse,
 )
 from ....services.speech_service import get_speech_service
+from ....services.realtime_audio_service import get_realtime_audio_service
 from ...dependencies import verify_api_key
 
 router = APIRouter(prefix="/speech", tags=["Speech"])
@@ -83,3 +84,25 @@ async def detect_language(
     audio_data = await file.read()
     result = await service.detect_language(audio_data)
     return result
+
+
+@router.post("/realtime")
+async def realtime_audio(
+    file: UploadFile = File(...),
+    api_key: str = Depends(verify_api_key),
+    service=Depends(get_realtime_audio_service),
+):
+    """
+    Process audio using real-time speech-to-speech models.
+
+    Args:
+        file: Audio file to process
+        api_key: API key for authentication
+        service: Real-time Audio service instance
+
+    Returns:
+        dict with transcribed text and execution info
+    """
+    audio_data = await file.read()
+    result = await service.process_audio(audio_data)
+    return result
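
For reference, a hypothetical client call against the new route. The host, port, v1 prefix (`/api/v1` here), and the auth header name (`X-API-Key`) are assumptions, not taken from this diff:

    # Hypothetical client for POST /speech/realtime.
    import httpx

    with open("sample.wav", "rb") as f:
        resp = httpx.post(
            "http://localhost:8000/api/v1/speech/realtime",
            headers={"X-API-Key": "change-me"},  # header name assumed
            files={"file": ("sample.wav", f, "audio/wav")},
            timeout=120.0,
        )
    print(resp.json())  # e.g. {"status": "completed", "text": "...", ...}
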
src/core/config.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Optional
 
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
@@ -20,6 +21,10 @@ class Settings(BaseSettings):
     version: str = "2.0.0"
     api_key: str = "change-me"
 
+    # External Providers for Speech (Optional)
+    groq_api_key: Optional[str] = None
+    openai_api_key: Optional[str] = None
+
     # Image generation model
     image_model_path: str = "./models/stable-diffusion-v1-5"
     image_steps: int = 4
@@ -46,6 +51,9 @@ class Settings(BaseSettings):
     # Whisper model for speech-to-text
     whisper_model_path: str = "./models/whisper"
 
+    # Real-time Audio model for speech-to-speech
+    realtime_audio_model_path: str = "./models/realtime_audio"
+
     # Device configuration
     device: str = "cuda"
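
Since `Settings` is a pydantic-settings `BaseSettings`, the new fields can presumably be supplied via environment variables (default field-name mapping; any `env_prefix` in the model_config would change the names). A minimal sketch:

    # Sketch: enabling external speech providers via the environment.
    # Env var names assume pydantic-settings' default field-name mapping.
    import os

    os.environ["GROQ_API_KEY"] = "gsk-..."
    os.environ["REALTIME_AUDIO_MODEL_PATH"] = "/opt/models/realtime_audio"

    from src.core.config import Settings

    settings = Settings()
    assert settings.groq_api_key == "gsk-..."
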
src/main.py
@@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 
-from .api.v1.endpoints import image, scoring, speech, video, vision
+from .api.v1.endpoints import image, scoring, speech, video
 from .core.config import settings
 from .core.logging import get_logger
 from .services.image_service import get_image_service
@@ -50,7 +50,6 @@ app.add_middleware(
 app.include_router(image.router, prefix=settings.api_v1_prefix)
 app.include_router(video.router, prefix=settings.api_v1_prefix)
 app.include_router(speech.router, prefix=settings.api_v1_prefix)
-app.include_router(vision.router, prefix=settings.api_v1_prefix)
 app.include_router(scoring.router, prefix=settings.api_v1_prefix)
 
 app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
113  src/services/realtime_audio_service.py  Normal file
@@ -0,0 +1,113 @@
+import time
+import torch
+from pathlib import Path
+from typing import Optional, Dict, Any
+
+from ..core.config import settings
+from ..core.logging import get_logger
+
+logger = get_logger("realtime_audio_service")
+
+
+class RealtimeAudioService:
+    def __init__(self):
+        self.model = None
+        self.processor = None
+        self.device = settings.device
+        self._initialized = False
+
+    def initialize(self):
+        if self._initialized:
+            return
+
+        logger.info("Loading Real-time Audio model")
+        try:
+            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+            # Default to PersonaPlex, but the naming is kept generic
+            model_id = "nvidia/personaplex-7b-v1"
+            model_path = Path(settings.realtime_audio_model_path)
+
+            if model_path.exists():
+                load_path = str(model_path)
+            else:
+                load_path = model_id
+
+            logger.info(f"Loading model from {load_path}")
+
+            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+            self.processor = AutoProcessor.from_pretrained(load_path)
+            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                load_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_safetensors=True,
+            ).to(self.device)
+
+            self._initialized = True
+            logger.info("Real-time Audio model loaded successfully")
+        except Exception as e:
+            logger.error("Failed to load Real-time Audio model", error=str(e))
+            self.model = None
+            self.processor = None
+
+    async def process_audio(self, audio_data: bytes, conversation_context: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Process audio input using real-time S2S model.
+        """
+        if not self._initialized:
+            self.initialize()
+
+        if self.model is None or self.processor is None:
+            return {
+                "status": "error",
+                "error": "Real-time Audio model not initialized",
+            }
+
+        start_time = time.time()
+
+        try:
+            import numpy as np
+            import librosa
+            import io
+
+            audio_buf = io.BytesIO(audio_data)
+            y, sr = librosa.load(audio_buf, sr=16000)
+
+            inputs = self.processor(y, sampling_rate=16000, return_tensors="pt").to(self.device)
+            inputs["input_features"] = inputs["input_features"].to(dtype=self.model.dtype)
+
+            with torch.no_grad():
+                generated_ids = self.model.generate(
+                    inputs["input_features"],
+                    max_new_tokens=256,
+                    do_sample=True,
+                    temperature=0.7,
+                )
+
+            transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+            execution_time = time.time() - start_time
+            logger.info("Processed real-time audio", time=execution_time)
+
+            return {
+                "status": "completed",
+                "text": transcription.strip(),
+                "execution_time": execution_time,
+                "model": "realtime-audio-v1",
+            }
+
+        except Exception as e:
+            logger.error("Real-time audio processing failed", error=str(e))
+            return {
+                "status": "error",
+                "error": str(e),
+            }
+
+
+_service = None
+
+
+def get_realtime_audio_service():
+    global _service
+    if _service is None:
+        _service = RealtimeAudioService()
+    return _service
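
A minimal sketch of exercising the new service outside FastAPI, assuming a local sample.wav and that the model weights are reachable (from the configured path or the hub):

    import asyncio

    from src.services.realtime_audio_service import get_realtime_audio_service

    async def main():
        service = get_realtime_audio_service()
        with open("sample.wav", "rb") as f:
            result = await service.process_audio(f.read())
        print(result)  # {"status": "completed", "text": ..., "execution_time": ...}

    asyncio.run(main())
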
src/services/speech_service.py
@@ -1,10 +1,13 @@
 import io
 import tempfile
 import time
+import os
+import urllib.parse
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
+import httpx
 from ..core.config import settings
 from ..core.logging import get_logger
 
@@ -21,36 +24,49 @@ class SpeechService:
     def initialize(self):
         if self._initialized:
             return
-        logger.info("Loading speech models")
-        try:
-            # Load TTS model (Coqui TTS)
-            self._load_tts_model()
-
-            # Load Whisper model for speech-to-text
-            self._load_whisper_model()
+        # We only need local models if external providers are NOT configured
+        if not settings.groq_api_key and not settings.openai_api_key:
+            logger.info(
+                "External providers not configured, loading local speech models"
+            )
+            try:
+                # Load TTS model (Coqui TTS)
+                self._load_tts_model()
 
-            self._initialized = True
-            logger.info("Speech models loaded successfully")
-        except Exception as e:
-            logger.error("Failed to load speech models", error=str(e))
-            # Don't raise - allow service to run with partial functionality
-            logger.warning("Speech service will have limited functionality")
+                # Load Whisper model for speech-to-text
+                self._load_whisper_model()
+            except Exception as e:
+                logger.error("Failed to load local speech models", error=str(e))
+        else:
+            logger.info("External speech providers detected (Groq/OpenAI)")
+
+        self._initialized = True
 
     def _load_tts_model(self):
         """Load TTS model for text-to-speech generation"""
         try:
             from TTS.api import TTS
 
-            # Use a fast, high-quality model
+            # Use a lightweight model for low-RAM systems (4GB)
             self.tts_model = TTS(
-                model_name="tts_models/en/ljspeech/tacotron2-DDC",
+                model_name="tts_models/multilingual/multi-dataset/xtts_v2",
                 progress_bar=False,
                 gpu=(self.device == "cuda"),
             )
-            logger.info("TTS model loaded")
+            logger.info("Local TTS model loaded (xtts_v2)")
         except Exception as e:
-            logger.warning("TTS model not available", error=str(e))
-            self.tts_model = None
+            logger.warning("XTTS failed, trying lighter model", error=str(e))
+            try:
+                self.tts_model = TTS(
+                    model_name="tts_models/en/ljspeech/tacotron2-DDC",
+                    progress_bar=False,
+                    gpu=False,
+                )
+                logger.info("Local TTS model loaded (tacotron2)")
+            except Exception as e2:
+                logger.warning("Local TTS model not available", error=str(e2))
+                self.tts_model = None
 
     def _load_whisper_model(self):
         """Load Whisper model for speech-to-text"""
@@ -65,9 +81,9 @@ class SpeechService:
                 )
             else:
                 self.whisper_model = whisper.load_model(model_size)
-            logger.info("Whisper model loaded", model=model_size)
+            logger.info("Local Whisper model loaded", model=model_size)
         except Exception as e:
-            logger.warning("Whisper model not available", error=str(e))
+            logger.warning("Local Whisper model not available", error=str(e))
             self.whisper_model = None
 
     async def generate(
@@ -85,40 +101,84 @@ class SpeechService:
         filename = f"{timestamp}_{hash(prompt) & 0xFFFFFF:06x}.wav"
         output_path = settings.output_dir / "audio" / filename
 
+        # Prefer OpenAI/Groq for high quality/speed if configured
+        if settings.openai_api_key:
+            logger.info("Generating speech via OpenAI API")
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.post(
+                        "https://api.openai.com/v1/audio/speech",
+                        headers={"Authorization": f"Bearer {settings.openai_api_key}"},
+                        json={
+                            "model": "tts-1",
+                            "input": prompt,
+                            "voice": voice or "alloy",
+                        },
+                        timeout=30.0,
+                    )
+                    response.raise_for_status()
+                    with open(output_path, "wb") as f:
+                        f.write(response.content)
+
+                generation_time = time.time() - start
+                return {
+                    "status": "completed",
+                    "file_path": f"/outputs/audio/{filename}",
+                    "generation_time": generation_time,
+                    "provider": "openai",
+                }
+            except Exception as e:
+                logger.error(
+                    "OpenAI speech generation failed, falling back", error=str(e)
+                )
+
+        # Fallback: Google Translate TTS (free, no API key needed)
+        try:
+            logger.info("Generating speech via Google Translate TTS")
+            lang = language or "pt-BR"
+            google_url = f"https://translate.google.com/translate_tts?ie=UTF-8&q={urllib.parse.quote(prompt)}&tl={lang}&client=tw-ob"
+            async with httpx.AsyncClient() as client:
+                response = await client.get(google_url, timeout=30.0)
+                response.raise_for_status()
+                with open(output_path, "wb") as f:
+                    f.write(response.content)
+
+            generation_time = time.time() - start
+            return {
+                "status": "completed",
+                "file_path": f"/outputs/audio/{filename}",
+                "generation_time": generation_time,
+                "provider": "google-translate",
+            }
+        except Exception as e:
+            logger.warning("Google Translate TTS failed", error=str(e))
+
         if self.tts_model is None:
-            logger.error("TTS model not available")
+            logger.error("No TTS provider available")
             return {
                 "status": "error",
-                "error": "TTS model not initialized",
+                "error": "No TTS provider initialized",
                 "file_path": None,
                 "generation_time": time.time() - start,
             }
 
         try:
-            logger.info(
-                "Generating speech",
-                text_length=len(prompt),
-                voice=voice,
-                language=language,
-            )
-
-            # Generate speech
+            logger.info("Generating speech via local model")
             self.tts_model.tts_to_file(
                 text=prompt,
                 file_path=str(output_path),
             )
 
             generation_time = time.time() - start
             logger.info("Speech generated", file=filename, time=generation_time)
 
             return {
                 "status": "completed",
                 "file_path": f"/outputs/audio/{filename}",
                 "generation_time": generation_time,
+                "provider": "local",
             }
 
         except Exception as e:
-            logger.error("Speech generation failed", error=str(e))
+            logger.error("Local speech generation failed", error=str(e))
             return {
                 "status": "error",
                 "error": str(e),
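
The resulting provider order in generate() is: OpenAI (if openai_api_key is set), then Google Translate TTS, then the local Coqui model. A sketch of a call with no keys configured; the keyword names follow the parameters used in the body, since the full signature is outside this hunk:

    import asyncio

    from src.services.speech_service import get_speech_service

    async def demo():
        service = get_speech_service()
        result = await service.generate(prompt="Olá, mundo!", language="pt-BR")
        # With no API keys set, expect "google-translate", or "local"
        # if the free endpoint is unreachable.
        print(result["provider"], result["file_path"])

    asyncio.run(demo())
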
@@ -127,96 +187,123 @@ class SpeechService:
             }
 
     async def to_text(self, audio_data: bytes) -> dict:
-        """Convert speech audio to text using Whisper"""
+        """Convert speech audio to text using Whisper (External or Local)"""
         if not self._initialized:
             self.initialize()
 
         start = time.time()
 
+        # 1. Try Groq (Ultra-fast Whisper)
+        if settings.groq_api_key:
+            logger.info("Transcribing via Groq Cloud")
+            try:
+                # Save to temp file for Groq API
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    tmp.write(audio_data)
+                    tmp_path = tmp.name
+
+                async with httpx.AsyncClient() as client:
+                    with open(tmp_path, "rb") as audio_file:
+                        files = {
+                            "file": (
+                                os.path.basename(tmp_path),
+                                audio_file,
+                                "audio/wav",
+                            )
+                        }
+                        data = {"model": "whisper-large-v3-turbo"}
+                        response = await client.post(
+                            "https://api.groq.com/openai/v1/audio/transcriptions",
+                            headers={
+                                "Authorization": f"Bearer {settings.groq_api_key}"
+                            },
+                            files=files,
+                            data=data,
+                            timeout=30.0,
+                        )
+                        response.raise_for_status()
+                        result = response.json()
+
+                os.unlink(tmp_path)
+                return {
+                    "text": result["text"].strip(),
+                    "language": result.get("language", "auto"),
+                    "confidence": 0.99,
+                    "provider": "groq",
+                }
+            except Exception as e:
+                logger.error("Groq transcription failed, falling back", error=str(e))
+
+        # 2. Try OpenAI
+        if settings.openai_api_key:
+            logger.info("Transcribing via OpenAI API")
+            try:
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    tmp.write(audio_data)
+                    tmp_path = tmp.name
+
+                async with httpx.AsyncClient() as client:
+                    with open(tmp_path, "rb") as audio_file:
+                        files = {
+                            "file": (
+                                os.path.basename(tmp_path),
+                                audio_file,
+                                "audio/wav",
+                            )
+                        }
+                        data = {"model": "whisper-1"}
+                        response = await client.post(
+                            "https://api.openai.com/v1/audio/transcriptions",
+                            headers={
+                                "Authorization": f"Bearer {settings.openai_api_key}"
+                            },
+                            files=files,
+                            data=data,
+                            timeout=30.0,
+                        )
+                        response.raise_for_status()
+                        result = response.json()
+
+                os.unlink(tmp_path)
+                return {
+                    "text": result["text"].strip(),
+                    "language": result.get("language", "auto"),
+                    "confidence": 0.99,
+                    "provider": "openai",
+                }
+            except Exception as e:
+                logger.error("OpenAI transcription failed, falling back", error=str(e))
+
+        # 3. Fallback to Local Whisper
         if self.whisper_model is None:
             logger.error("Whisper model not available")
-            return {
-                "text": "",
-                "language": None,
-                "confidence": 0.0,
-                "error": "Whisper model not initialized",
-            }
+            return {"text": "", "error": "No STT provider available"}
 
         try:
             # Save audio to temporary file
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                 tmp.write(audio_data)
                 tmp_path = tmp.name
 
-            logger.info("Transcribing audio", file_size=len(audio_data))
-
-            # Transcribe
+            logger.info("Transcribing via local model")
             result = self.whisper_model.transcribe(tmp_path)
 
             # Clean up temp file
-            import os
-
             os.unlink(tmp_path)
 
             transcription_time = time.time() - start
             logger.info(
                 "Audio transcribed",
                 text_length=len(result["text"]),
                 language=result.get("language"),
                 time=transcription_time,
             )
 
             return {
                 "text": result["text"].strip(),
-                "language": result.get("language", "en"),
-                "confidence": 0.95,  # Whisper doesn't provide confidence directly
+                "language": result.get("language", "auto"),
+                "confidence": 0.95,
+                "provider": "local",
             }
 
         except Exception as e:
             logger.error("Speech-to-text failed", error=str(e))
-            return {
-                "text": "",
-                "language": None,
-                "confidence": 0.0,
-                "error": str(e),
-            }
+            return {"text": "", "error": str(e)}
 
     async def detect_language(self, audio_data: bytes) -> dict:
-        """Detect the language of spoken audio"""
-        if not self._initialized:
-            self.initialize()
-
-        if self.whisper_model is None:
-            return {"language": None, "error": "Whisper model not initialized"}
-
-        try:
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-                tmp.write(audio_data)
-                tmp_path = tmp.name
-
-            import whisper
-
-            # Load audio and detect language
-            audio = whisper.load_audio(tmp_path)
-            audio = whisper.pad_or_trim(audio)
-            mel = whisper.log_mel_spectrogram(audio).to(self.whisper_model.device)
-            _, probs = self.whisper_model.detect_language(mel)
-
-            os.unlink(tmp_path)
-
-            detected_lang = max(probs, key=probs.get)
-            confidence = probs[detected_lang]
-
-            return {
-                "language": detected_lang,
-                "confidence": confidence,
-            }
-
-        except Exception as e:
-            logger.error("Language detection failed", error=str(e))
-            return {"language": None, "error": str(e)}
+        """Detect language (simplified to reuse to_text if needed)"""
+        # Just use to_text and return the language field
+        result = await self.to_text(audio_data)
+        return {
+            "language": result.get("language"),
+            "confidence": result.get("confidence", 0.0),
+        }
 
 
 _service = None
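
detect_language is now a thin wrapper over to_text: a single transcription pass instead of the separate mel-spectrogram probability step, so it inherits whichever provider chain to_text used. A minimal sketch:

    import asyncio

    from src.services.speech_service import get_speech_service

    async def demo():
        service = get_speech_service()
        with open("sample.wav", "rb") as f:
            print(await service.detect_language(f.read()))
        # e.g. {"language": "en", "confidence": 0.95}

    asyncio.run(demo())
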