feat: add realtime audio service and speech endpoints
This commit is contained in:
parent
e37554087e
commit
09c59ddc6b
8 changed files with 339 additions and 108 deletions
BIN
outputs/audio/20260311_232859_3d0378.wav
Normal file
BIN
outputs/audio/20260311_232859_3d0378.wav
Normal file
Binary file not shown.
|
|
@ -10,9 +10,9 @@ python-json-logger==2.0.7
|
||||||
|
|
||||||
# Generation Libraries
|
# Generation Libraries
|
||||||
diffusers==0.30.3
|
diffusers==0.30.3
|
||||||
torch==2.5.1
|
torch>=2.5.0
|
||||||
torchaudio==2.5.1
|
torchaudio>=2.5.0
|
||||||
torchvision==0.20.1
|
torchvision>=0.20.0
|
||||||
transformers==4.46.0
|
transformers==4.46.0
|
||||||
accelerate==1.1.1
|
accelerate==1.1.1
|
||||||
safetensors==0.4.5
|
safetensors==0.4.5
|
||||||
|
|
@ -22,6 +22,7 @@ Pillow==11.0.0
|
||||||
openai-whisper==20231117
|
openai-whisper==20231117
|
||||||
TTS==0.22.0
|
TTS==0.22.0
|
||||||
scipy==1.14.1
|
scipy==1.14.1
|
||||||
|
librosa==0.10.2.post1
|
||||||
|
|
||||||
# Video Processing
|
# Video Processing
|
||||||
imageio==2.36.0
|
imageio==2.36.0
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
from . import image, scoring, speech, video, vision
|
from . import image, scoring, speech, video
|
||||||
|
|
||||||
__all__ = ["image", "video", "speech", "vision", "scoring"]
|
__all__ = ["image", "video", "speech", "scoring"]
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from ....schemas.generation import (
|
||||||
SpeechToTextResponse,
|
SpeechToTextResponse,
|
||||||
)
|
)
|
||||||
from ....services.speech_service import get_speech_service
|
from ....services.speech_service import get_speech_service
|
||||||
|
from ....services.realtime_audio_service import get_realtime_audio_service
|
||||||
from ...dependencies import verify_api_key
|
from ...dependencies import verify_api_key
|
||||||
|
|
||||||
router = APIRouter(prefix="/speech", tags=["Speech"])
|
router = APIRouter(prefix="/speech", tags=["Speech"])
|
||||||
|
|
@ -83,3 +84,25 @@ async def detect_language(
|
||||||
audio_data = await file.read()
|
audio_data = await file.read()
|
||||||
result = await service.detect_language(audio_data)
|
result = await service.detect_language(audio_data)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/realtime")
async def realtime_audio(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key),
    service=Depends(get_realtime_audio_service),
):
    """
    Run an uploaded audio file through the real-time audio service.

    Args:
        file: Uploaded audio file; its full contents are read into memory
        api_key: Value produced by the ``verify_api_key`` dependency
        service: Instance supplied by ``get_realtime_audio_service``

    Returns:
        Whatever ``service.process_audio`` returns for the uploaded bytes
    """
    payload = await file.read()
    return await service.process_audio(payload)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
@ -20,6 +21,10 @@ class Settings(BaseSettings):
|
||||||
version: str = "2.0.0"
|
version: str = "2.0.0"
|
||||||
api_key: str = "change-me"
|
api_key: str = "change-me"
|
||||||
|
|
||||||
|
# External Providers for Speech (Optional)
|
||||||
|
groq_api_key: Optional[str] = None
|
||||||
|
openai_api_key: Optional[str] = None
|
||||||
|
|
||||||
# Image generation model
|
# Image generation model
|
||||||
image_model_path: str = "./models/stable-diffusion-v1-5"
|
image_model_path: str = "./models/stable-diffusion-v1-5"
|
||||||
image_steps: int = 4
|
image_steps: int = 4
|
||||||
|
|
@ -46,6 +51,9 @@ class Settings(BaseSettings):
|
||||||
# Whisper model for speech-to-text
|
# Whisper model for speech-to-text
|
||||||
whisper_model_path: str = "./models/whisper"
|
whisper_model_path: str = "./models/whisper"
|
||||||
|
|
||||||
|
# Real-time Audio model for speech-to-speech
|
||||||
|
realtime_audio_model_path: str = "./models/realtime_audio"
|
||||||
|
|
||||||
# Device configuration
|
# Device configuration
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from .api.v1.endpoints import image, scoring, speech, video, vision
|
from .api.v1.endpoints import image, scoring, speech, video
|
||||||
from .core.config import settings
|
from .core.config import settings
|
||||||
from .core.logging import get_logger
|
from .core.logging import get_logger
|
||||||
from .services.image_service import get_image_service
|
from .services.image_service import get_image_service
|
||||||
|
|
@ -50,7 +50,6 @@ app.add_middleware(
|
||||||
app.include_router(image.router, prefix=settings.api_v1_prefix)
|
app.include_router(image.router, prefix=settings.api_v1_prefix)
|
||||||
app.include_router(video.router, prefix=settings.api_v1_prefix)
|
app.include_router(video.router, prefix=settings.api_v1_prefix)
|
||||||
app.include_router(speech.router, prefix=settings.api_v1_prefix)
|
app.include_router(speech.router, prefix=settings.api_v1_prefix)
|
||||||
app.include_router(vision.router, prefix=settings.api_v1_prefix)
|
|
||||||
app.include_router(scoring.router, prefix=settings.api_v1_prefix)
|
app.include_router(scoring.router, prefix=settings.api_v1_prefix)
|
||||||
|
|
||||||
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||||||
|
|
|
||||||
113
src/services/realtime_audio_service.py
Normal file
113
src/services/realtime_audio_service.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from ..core.config import settings
|
||||||
|
from ..core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("realtime_audio_service")
|
||||||
|
|
||||||
|
class RealtimeAudioService:
    """Lazily loaded wrapper around a Hugging Face speech seq2seq model.

    The processor and model are loaded on first use by ``initialize`` and
    reused by every later ``process_audio`` call.
    """

    # Hub identifier used when nothing exists at
    # ``settings.realtime_audio_model_path``.
    DEFAULT_MODEL_ID = "nvidia/personaplex-7b-v1"
    # Sampling rate (Hz) the audio is resampled to before feature extraction.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self):
        """Create an unloaded service bound to ``settings.device``."""
        self.model = None
        self.processor = None
        self.device = settings.device
        self._initialized = False

    def initialize(self):
        """Load the processor and model; a no-op once loading has succeeded.

        The local directory ``settings.realtime_audio_model_path`` is used
        when it exists, otherwise ``DEFAULT_MODEL_ID`` is passed to
        ``from_pretrained``. Failures are logged rather than raised: model
        and processor are reset to ``None`` and ``_initialized`` stays
        ``False``, so the next call attempts the load again.
        """
        if self._initialized:
            return

        logger.info("Loading Real-time Audio model")
        try:
            # Imported lazily so the module can be imported without
            # transformers being loaded up front.
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

            model_path = Path(settings.realtime_audio_model_path)
            load_path = (
                str(model_path) if model_path.exists() else self.DEFAULT_MODEL_ID
            )
            logger.info(f"Loading model from {load_path}")

            # Half precision on CUDA, full precision everywhere else.
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            self.processor = AutoProcessor.from_pretrained(load_path)
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                load_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            ).to(self.device)

            self._initialized = True
            logger.info("Real-time Audio model loaded successfully")
        except Exception as e:
            logger.error("Failed to load Real-time Audio model", error=str(e))
            # Drop a half-loaded pair so callers see a consistent
            # "not available" state.
            self.model = None
            self.processor = None

    async def process_audio(
        self,
        audio_data: bytes,
        conversation_context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Decode an audio payload, run generation and return the decoded text.

        Args:
            audio_data: Raw bytes of an audio file readable by ``librosa.load``.
            conversation_context: Accepted for interface compatibility; the
                current implementation does not use it.

        Returns:
            On success a dict with ``status`` ("completed"), ``text``,
            ``execution_time`` (seconds) and ``model``. On any failure a dict
            with ``status`` ("error") and an ``error`` message; exceptions are
            not propagated.
        """
        # Reject empty uploads before paying for a model load or a decoder
        # round-trip that can only fail.
        if not audio_data:
            return {
                "status": "error",
                "error": "Empty audio payload",
            }

        # NOTE(review): loading and generation below run synchronously inside
        # this coroutine, so the event loop is blocked for their duration --
        # consider offloading to a worker thread.
        if not self._initialized:
            self.initialize()

        if self.model is None or self.processor is None:
            return {
                "status": "error",
                "error": "Real-time Audio model not initialized",
            }

        # Monotonic clock: elapsed time is immune to system clock changes.
        start_time = time.perf_counter()

        try:
            import io

            import librosa

            waveform, _ = librosa.load(
                io.BytesIO(audio_data), sr=self.TARGET_SAMPLE_RATE
            )

            inputs = self.processor(
                waveform,
                sampling_rate=self.TARGET_SAMPLE_RATE,
                return_tensors="pt",
            ).to(self.device)
            # Match the model's dtype (float16 on CUDA) to avoid a dtype
            # mismatch inside generate().
            input_features = inputs["input_features"].to(dtype=self.model.dtype)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_features,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                )

            transcription = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]

            execution_time = time.perf_counter() - start_time
            logger.info("Processed real-time audio", time=execution_time)

            return {
                "status": "completed",
                "text": transcription.strip(),
                "execution_time": execution_time,
                "model": "realtime-audio-v1",
            }

        except Exception as e:
            logger.error("Real-time audio processing failed", error=str(e))
            return {
                "status": "error",
                "error": str(e),
            }
|
||||||
|
|
||||||
|
_service = None
|
||||||
|
|
||||||
|
def get_realtime_audio_service():
    """Return the shared module-level service, creating it on the first call."""
    global _service
    if _service is not None:
        return _service
    _service = RealtimeAudioService()
    return _service
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
import io
|
import io
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..core.logging import get_logger
|
from ..core.logging import get_logger
|
||||||
|
|
||||||
|
|
@ -21,35 +24,48 @@ class SpeechService:
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
logger.info("Loading speech models")
|
|
||||||
|
# We only need local models if external providers are NOT configured
|
||||||
|
if not settings.groq_api_key and not settings.openai_api_key:
|
||||||
|
logger.info(
|
||||||
|
"External providers not configured, loading local speech models"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# Load TTS model (Coqui TTS)
|
# Load TTS model (Coqui TTS)
|
||||||
self._load_tts_model()
|
self._load_tts_model()
|
||||||
|
|
||||||
# Load Whisper model for speech-to-text
|
# Load Whisper model for speech-to-text
|
||||||
self._load_whisper_model()
|
self._load_whisper_model()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to load local speech models", error=str(e))
|
||||||
|
else:
|
||||||
|
logger.info("External speech providers detected (Groq/OpenAI)")
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("Speech models loaded successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to load speech models", error=str(e))
|
|
||||||
# Don't raise - allow service to run with partial functionality
|
|
||||||
logger.warning("Speech service will have limited functionality")
|
|
||||||
|
|
||||||
def _load_tts_model(self):
|
def _load_tts_model(self):
|
||||||
"""Load TTS model for text-to-speech generation"""
|
"""Load TTS model for text-to-speech generation"""
|
||||||
try:
|
try:
|
||||||
from TTS.api import TTS
|
from TTS.api import TTS
|
||||||
|
|
||||||
# Use a fast, high-quality model
|
# Prefer XTTS v2 (multilingual); the except branch below falls back to the lighter Tacotron2 model
|
||||||
self.tts_model = TTS(
|
self.tts_model = TTS(
|
||||||
model_name="tts_models/en/ljspeech/tacotron2-DDC",
|
model_name="tts_models/multilingual/multi-dataset/xtts_v2",
|
||||||
progress_bar=False,
|
progress_bar=False,
|
||||||
gpu=(self.device == "cuda"),
|
gpu=(self.device == "cuda"),
|
||||||
)
|
)
|
||||||
logger.info("TTS model loaded")
|
logger.info("Local TTS model loaded (xtts_v2)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("TTS model not available", error=str(e))
|
logger.warning("XTTS failed, trying lighter model", error=str(e))
|
||||||
|
try:
|
||||||
|
self.tts_model = TTS(
|
||||||
|
model_name="tts_models/en/ljspeech/tacotron2-DDC",
|
||||||
|
progress_bar=False,
|
||||||
|
gpu=False,
|
||||||
|
)
|
||||||
|
logger.info("Local TTS model loaded (tacotron2)")
|
||||||
|
except Exception as e2:
|
||||||
|
logger.warning("Local TTS model not available", error=str(e2))
|
||||||
self.tts_model = None
|
self.tts_model = None
|
||||||
|
|
||||||
def _load_whisper_model(self):
|
def _load_whisper_model(self):
|
||||||
|
|
@ -65,9 +81,9 @@ class SpeechService:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.whisper_model = whisper.load_model(model_size)
|
self.whisper_model = whisper.load_model(model_size)
|
||||||
logger.info("Whisper model loaded", model=model_size)
|
logger.info("Local Whisper model loaded", model=model_size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Whisper model not available", error=str(e))
|
logger.warning("Local Whisper model not available", error=str(e))
|
||||||
self.whisper_model = None
|
self.whisper_model = None
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
|
|
@ -85,40 +101,84 @@ class SpeechService:
|
||||||
filename = f"{timestamp}_{hash(prompt) & 0xFFFFFF:06x}.wav"
|
filename = f"{timestamp}_{hash(prompt) & 0xFFFFFF:06x}.wav"
|
||||||
output_path = settings.output_dir / "audio" / filename
|
output_path = settings.output_dir / "audio" / filename
|
||||||
|
|
||||||
|
# Prefer OpenAI/Groq for high quality/speed if configured
|
||||||
|
if settings.openai_api_key:
|
||||||
|
logger.info("Generating speech via OpenAI API")
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.openai.com/v1/audio/speech",
|
||||||
|
headers={"Authorization": f"Bearer {settings.openai_api_key}"},
|
||||||
|
json={
|
||||||
|
"model": "tts-1",
|
||||||
|
"input": prompt,
|
||||||
|
"voice": voice or "alloy",
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
generation_time = time.time() - start
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"file_path": f"/outputs/audio/{filename}",
|
||||||
|
"generation_time": generation_time,
|
||||||
|
"provider": "openai",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"OpenAI speech generation failed, falling back", error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: Google Translate TTS (free, no API key needed)
|
||||||
|
try:
|
||||||
|
logger.info("Generating speech via Google Translate TTS")
|
||||||
|
lang = language or "pt-BR"
|
||||||
|
google_url = f"https://translate.google.com/translate_tts?ie=UTF-8&q={urllib.parse.quote(prompt)}&tl={lang}&client=tw-ob"
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(google_url, timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
generation_time = time.time() - start
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"file_path": f"/outputs/audio/{filename}",
|
||||||
|
"generation_time": generation_time,
|
||||||
|
"provider": "google-translate",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Google Translate TTS failed", error=str(e))
|
||||||
|
|
||||||
if self.tts_model is None:
|
if self.tts_model is None:
|
||||||
logger.error("TTS model not available")
|
logger.error("No TTS provider available")
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": "TTS model not initialized",
|
"error": "No TTS provider initialized",
|
||||||
"file_path": None,
|
"file_path": None,
|
||||||
"generation_time": time.time() - start,
|
"generation_time": time.time() - start,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info("Generating speech via local model")
|
||||||
"Generating speech",
|
|
||||||
text_length=len(prompt),
|
|
||||||
voice=voice,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate speech
|
|
||||||
self.tts_model.tts_to_file(
|
self.tts_model.tts_to_file(
|
||||||
text=prompt,
|
text=prompt,
|
||||||
file_path=str(output_path),
|
file_path=str(output_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
generation_time = time.time() - start
|
generation_time = time.time() - start
|
||||||
logger.info("Speech generated", file=filename, time=generation_time)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"file_path": f"/outputs/audio/{filename}",
|
"file_path": f"/outputs/audio/{filename}",
|
||||||
"generation_time": generation_time,
|
"generation_time": generation_time,
|
||||||
|
"provider": "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Speech generation failed", error=str(e))
|
logger.error("Local speech generation failed", error=str(e))
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -127,97 +187,124 @@ class SpeechService:
|
||||||
}
|
}
|
||||||
|
|
||||||
async def to_text(self, audio_data: bytes) -> dict:
|
async def to_text(self, audio_data: bytes) -> dict:
|
||||||
"""Convert speech audio to text using Whisper"""
|
"""Convert speech audio to text using Whisper (External or Local)"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.initialize()
|
self.initialize()
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
if self.whisper_model is None:
|
# 1. Try Groq (Ultra-fast Whisper)
|
||||||
logger.error("Whisper model not available")
|
if settings.groq_api_key:
|
||||||
return {
|
logger.info("Transcribing via Groq Cloud")
|
||||||
"text": "",
|
|
||||||
"language": None,
|
|
||||||
"confidence": 0.0,
|
|
||||||
"error": "Whisper model not initialized",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Save audio to temporary file
|
# Save to temp file for Groq API
|
||||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||||
tmp.write(audio_data)
|
tmp.write(audio_data)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
|
|
||||||
logger.info("Transcribing audio", file_size=len(audio_data))
|
async with httpx.AsyncClient() as client:
|
||||||
|
with open(tmp_path, "rb") as audio_file:
|
||||||
# Transcribe
|
files = {
|
||||||
result = self.whisper_model.transcribe(tmp_path)
|
"file": (
|
||||||
|
os.path.basename(tmp_path),
|
||||||
# Clean up temp file
|
audio_file,
|
||||||
import os
|
"audio/wav",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
data = {"model": "whisper-large-v3-turbo"}
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.groq.com/openai/v1/audio/transcriptions",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {settings.groq_api_key}"
|
||||||
|
},
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
|
return {
|
||||||
|
"text": result["text"].strip(),
|
||||||
|
"language": result.get("language", "auto"),
|
||||||
|
"confidence": 0.99,
|
||||||
|
"provider": "groq",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Groq transcription failed, falling back", error=str(e))
|
||||||
|
|
||||||
transcription_time = time.time() - start
|
# 2. Try OpenAI
|
||||||
logger.info(
|
if settings.openai_api_key:
|
||||||
"Audio transcribed",
|
logger.info("Transcribing via OpenAI API")
|
||||||
text_length=len(result["text"]),
|
try:
|
||||||
language=result.get("language"),
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||||
time=transcription_time,
|
tmp.write(audio_data)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
with open(tmp_path, "rb") as audio_file:
|
||||||
|
files = {
|
||||||
|
"file": (
|
||||||
|
os.path.basename(tmp_path),
|
||||||
|
audio_file,
|
||||||
|
"audio/wav",
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
data = {"model": "whisper-1"}
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.openai.com/v1/audio/transcriptions",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {settings.openai_api_key}"
|
||||||
|
},
|
||||||
|
files=files,
|
||||||
|
data=data,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
return {
|
||||||
|
"text": result["text"].strip(),
|
||||||
|
"language": result.get("language", "auto"),
|
||||||
|
"confidence": 0.99,
|
||||||
|
"provider": "openai",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("OpenAI transcription failed, falling back", error=str(e))
|
||||||
|
|
||||||
|
# 3. Fallback to Local Whisper
|
||||||
|
if self.whisper_model is None:
|
||||||
|
return {"text": "", "error": "No STT provider available"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||||
|
tmp.write(audio_data)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
logger.info("Transcribing via local model")
|
||||||
|
result = self.whisper_model.transcribe(tmp_path)
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"text": result["text"].strip(),
|
"text": result["text"].strip(),
|
||||||
"language": result.get("language", "en"),
|
"language": result.get("language", "auto"),
|
||||||
"confidence": 0.95, # Whisper doesn't provide confidence directly
|
"confidence": 0.95,
|
||||||
|
"provider": "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Speech-to-text failed", error=str(e))
|
return {"text": "", "error": str(e)}
|
||||||
return {
|
|
||||||
"text": "",
|
|
||||||
"language": None,
|
|
||||||
"confidence": 0.0,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def detect_language(self, audio_data: bytes) -> dict:
|
async def detect_language(self, audio_data: bytes) -> dict:
|
||||||
"""Detect the language of spoken audio"""
|
"""Detect language (simplified to reuse to_text if needed)"""
|
||||||
if not self._initialized:
|
# Just use to_text and return the language field
|
||||||
self.initialize()
|
result = await self.to_text(audio_data)
|
||||||
|
|
||||||
if self.whisper_model is None:
|
|
||||||
return {"language": None, "error": "Whisper model not initialized"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
|
||||||
tmp.write(audio_data)
|
|
||||||
tmp_path = tmp.name
|
|
||||||
|
|
||||||
import whisper
|
|
||||||
|
|
||||||
# Load audio and detect language
|
|
||||||
audio = whisper.load_audio(tmp_path)
|
|
||||||
audio = whisper.pad_or_trim(audio)
|
|
||||||
mel = whisper.log_mel_spectrogram(audio).to(self.whisper_model.device)
|
|
||||||
_, probs = self.whisper_model.detect_language(mel)
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
|
|
||||||
detected_lang = max(probs, key=probs.get)
|
|
||||||
confidence = probs[detected_lang]
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"language": detected_lang,
|
"language": result.get("language"),
|
||||||
"confidence": confidence,
|
"confidence": result.get("confidence", 0.0),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Language detection failed", error=str(e))
|
|
||||||
return {"language": None, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
_service = None
|
_service = None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue