engine / api_client.py
VeuReu's picture
Upload 3 files
04a4cfc verified
raw
history blame
12.8 kB
# api_client.py (UI - Space "veureu")
import os
import requests
import base64
import zipfile
import io
from typing import Iterable, Dict, Any
class APIClient:
"""
Cliente para 'engine':
POST /jobs -> {"job_id": "..."}
GET /jobs/{job_id}/status -> {"status": "queued|processing|done|failed", ...}
GET /jobs/{job_id}/result -> JobResult {"book": {...}, "une": {...}, ...}
"""
def __init__(self, base_url: str, use_mock: bool = False, data_dir: str | None = None, token: str | None = None, timeout: int = 180):
self.base_url = base_url.rstrip("/")
# La URL para el servicio TTS es la misma que la base_url para los Spaces de HF
self.tts_url = self.base_url
self.use_mock = use_mock
self.data_dir = data_dir
self.timeout = timeout
self.session = requests.Session()
# Permite inyectar el token del engine via secret/var en el Space UI
token = token or os.getenv("API_SHARED_TOKEN")
if token:
self.session.headers.update({"Authorization": f"Bearer {token}"})
# ---- modo real (engine) ----
def _post_jobs(self, video_path: str, modes: Iterable[str]) -> Dict[str, Any]:
url = f"{self.base_url}/jobs"
files = {"file": (os.path.basename(video_path), open(video_path, "rb"), "application/octet-stream")}
data = {"modes": ",".join(modes)}
r = self.session.post(url, files=files, data=data, timeout=self.timeout)
r.raise_for_status()
return r.json() # {"job_id": ...}
def _get_status(self, job_id: str) -> Dict[str, Any]:
url = f"{self.base_url}/jobs/{job_id}/status"
r = self.session.get(url, timeout=self.timeout)
r.raise_for_status()
return r.json()
def _get_result(self, job_id: str) -> Dict[str, Any]:
url = f"{self.base_url}/jobs/{job_id}/result"
r = self.session.get(url, timeout=self.timeout)
r.raise_for_status()
return r.json() # JobResult (book/une/... según engine)
# ---- API que usa streamlit_app.py ----
def process_video(self, video_path: str, modes: Iterable[str]) -> Dict[str, Any]:
"""Devuelve {"job_id": "..."}"""
if self.use_mock:
return {"job_id": "mock-123"}
return self._post_jobs(video_path, modes)
def get_job(self, job_id: str) -> Dict[str, Any]:
"""
La UI espera algo del estilo:
{"status":"done","results":{"book":{...},"une":{...}}}
Adaptamos la respuesta de /result del engine a ese contrato.
"""
if self.use_mock:
# resultado inmediato de prueba
return {
"status": "done",
"results": {
"book": {"text": "Text d'exemple (book)", "mp3_bytes": b""},
"une": {"srt": "1\n00:00:00,000 --> 00:00:01,000\nExemple UNE\n", "mp3_bytes": b""},
}
}
# Opción 1: si quieres chequear estado primero
st = self._get_status(job_id)
if st.get("status") in {"queued", "processing"}:
return {"status": st.get("status", "queued")}
res = self._get_result(job_id)
# 'res' viene como JobResult del engine: {"book": {...}, "une": {...}, ...}
# La UI consume 'results' con claves "book"/"une"; si tus claves ya son iguales, pasa directo:
results = {}
if "book" in res:
results["book"] = {
"text": res["book"].get("text"),
# si sirves URLs en el engine, podrías mapear "book_mp3_url" a descarga directa;
# la UI actual espera "mp3_bytes" sólo en mock, así que lo dejamos fuera.
}
if "une" in res:
results["une"] = {
"srt": res["une"].get("srt"),
}
# Si res incluye "characters"/"metrics", la UI también los guarda:
for k in ("book", "une"):
if k in res:
if "characters" in res[k]:
results[k]["characters"] = res[k]["characters"]
if "metrics" in res[k]:
results[k]["metrics"] = res[k]["metrics"]
status = "done" if results else st.get("status", "unknown")
return {"status": status, "results": results}
def tts_matxa(self, text: str, voice: str = "central/grau") -> dict:
"""
Llama al space 'tts' para sintetizar audio.
Args:
text (str): Texto a sintetizar.
voice (str): Voz de Matxa a usar (p.ej. 'central/alvocat').
Returns:
dict: {'mp3_data_url': 'data:audio/mpeg;base64,...'}
"""
if not self.tts_url:
raise ValueError("La URL del servei TTS no està configurada (API_TTS_URL)")
url = f"{self.tts_url.rstrip('/')}/tts/text"
data = {
"texto": text,
"voice": voice,
"formato": "mp3"
}
try:
r = requests.post(url, data=data, timeout=self.timeout)
r.raise_for_status()
# Devolver los bytes directamente para que el cliente los pueda concatenar
return {"mp3_bytes": r.content}
except requests.exceptions.RequestException as e:
print(f"Error cridant a TTS: {e}")
# Devolvemos un diccionario con error para que la UI lo muestre
return {"error": str(e)}
def rebuild_video_with_ad(self, video_path: str, srt_path: str) -> dict:
"""
Llama al space 'tts' para reconstruir un vídeo con audiodescripció a partir de un SRT.
El servidor devuelve un ZIP, y de ahí extraemos el MP4 final.
"""
if not self.tts_url:
raise ValueError("La URL del servei TTS no està configurada (API_TTS_URL)")
url = f"{self.tts_url.rstrip('/')}/tts/srt"
try:
files = {
'video': (os.path.basename(video_path), open(video_path, 'rb'), 'video/mp4'),
'srt': (os.path.basename(srt_path), open(srt_path, 'rb'), 'application/x-subrip')
}
data = {"include_final_mp4": 1}
r = requests.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
# El servidor devuelve un ZIP, lo procesamos en memoria
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
# Buscamos el archivo .mp4 dentro del ZIP
for filename in z.namelist():
if filename.endswith('.mp4'):
video_bytes = z.read(filename)
return {"video_bytes": video_bytes}
# Si no se encuentra el MP4 en el ZIP
return {"error": "No se encontró el archivo de vídeo MP4 en la respuesta del servidor."}
except requests.exceptions.RequestException as e:
print(f"Error cridant a la reconstrucció de vídeo: {e}")
return {"error": str(e)}
except zipfile.BadZipFile:
return {"error": "La respuesta del servidor no fue un archivo ZIP válido."}
def create_initial_casting(self, video_path: str = None, video_bytes: bytes = None, video_name: str = None, epsilon: float = 0.5, min_cluster_size: int = 2) -> dict:
"""
Llama al endpoint del space 'engine' para crear el 'initial casting'.
Envía el vídeo recién importado como archivo y los parámetros de clustering.
Args:
video_path: Path to video file (if reading from disk)
video_bytes: Video file bytes (if already in memory)
video_name: Name for the video file
epsilon: Clustering epsilon parameter
min_cluster_size: Minimum cluster size parameter
"""
url = f"{self.base_url}/create_initial_casting"
try:
# Prepare file data
if video_bytes:
filename = video_name or "video.mp4"
files = {
"video": (filename, video_bytes, "video/mp4"),
}
elif video_path:
with open(video_path, "rb") as f:
files = {
"video": (os.path.basename(video_path), f.read(), "video/mp4"),
}
else:
return {"error": "Either video_path or video_bytes must be provided"}
data = {
"epsilon": str(epsilon),
"min_cluster_size": str(min_cluster_size),
}
r = self.session.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
return r.json() if r.headers.get("content-type", "").startswith("application/json") else {"ok": True}
except requests.exceptions.RequestException as e:
return {"error": str(e)}
except Exception as e:
return {"error": f"Unexpected error: {str(e)}"}
def generate_audio_from_text_file(self, text_content: str, voice: str = "central/grau") -> dict:
"""
Genera un único MP3 a partir de un texto largo, usando el endpoint de SRT.
1. Convierte el texto en un SRT falso.
2. Llama a /tts/srt con el SRT.
3. Extrae el 'ad_master.mp3' del ZIP resultante.
"""
if not self.tts_url:
raise ValueError("La URL del servei TTS no està configurada (API_TTS_URL)")
# 1. Crear un SRT falso en memoria
srt_content = ""
start_time = 0
for i, line in enumerate(text_content.strip().split('\n')):
line = line.strip()
if not line:
continue
# Asignar 5 segundos por línea, un valor simple
end_time = start_time + 5
def format_time(seconds):
h = int(seconds / 3600)
m = int((seconds % 3600) / 60)
s = int(seconds % 60)
ms = int((seconds - int(seconds)) * 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
srt_content += f"{i+1}\n"
srt_content += f"{format_time(start_time)} --> {format_time(end_time)}\n"
srt_content += f"{line}\n\n"
start_time = end_time
if not srt_content:
return {"error": "El texto proporcionado estaba vacío o no se pudo procesar."}
# 2. Llamar al endpoint /tts/srt
url = f"{self.tts_url.rstrip('/')}/tts/srt"
try:
files = {
'srt': ('fake_ad.srt', srt_content, 'application/x-subrip')
}
data = {"voice": voice, "ad_format": "mp3"}
r = requests.post(url, files=files, data=data, timeout=self.timeout * 5)
r.raise_for_status()
# 3. Extraer 'ad_master.mp3' del ZIP
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
for filename in z.namelist():
if filename == 'ad_master.mp3':
mp3_bytes = z.read(filename)
return {"mp3_bytes": mp3_bytes}
return {"error": "No se encontró 'ad_master.mp3' en la respuesta del servidor."}
except requests.exceptions.RequestException as e:
return {"error": f"Error llamando a la API de SRT: {e}"}
except zipfile.BadZipFile:
return {"error": "La respuesta del servidor no fue un archivo ZIP válido."}
def tts_long_text(self, text: str, voice: str = "central/grau") -> dict:
"""
Llama al endpoint '/tts/text_long' para sintetizar un texto largo.
La API se encarga de todo el procesamiento.
"""
if not self.tts_url:
raise ValueError("La URL del servei TTS no està configurada (API_TTS_URL)")
url = f"{self.tts_url.rstrip('/')}/tts/text_long"
data = {
"texto": text,
"voice": voice,
"formato": "mp3"
}
try:
# Usamos un timeout más largo por si el texto es muy extenso
r = requests.post(url, data=data, timeout=self.timeout * 10)
r.raise_for_status()
return {"mp3_bytes": r.content}
except requests.exceptions.RequestException as e:
print(f"Error cridant a TTS per a text llarg: {e}")
return {"error": str(e)}