# Source: engine/refinement/introspection.py — Hugging Face Space by VeuReu
# Commit 31d4d14 ("Upload 10 files", verified), 10.5 kB
"""Module for the "introspection" agent.

Implements:
- A training step that learns from HITL corrections by comparing the automatic
  `une_ad` (MoE/Salamandra) with the HITL-corrected version of the same
  record (column `ok_une_ad`).
- An introspection step that applies the learned rules to a new SRT using
  GPT-4o-mini.
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
from pathlib import Path
from typing import Iterable, List, Optional, Tuple
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
logger = logging.getLogger(__name__)

# --- Paths and constants ---
# Directory that contains this module.
BASE_DIR = Path(__file__).resolve().parent
# Expected layout: .../hf_spaces/engine/refinement/introspection.py.
# BASE_DIR is therefore ".../engine/refinement", and the repository root is the
# directory that contains "engine" (BASE_DIR's grandparent).
REPO_ROOT = BASE_DIR.parents[1]

DEMO_DIR = REPO_ROOT.joinpath("demo")
DEMO_TEMP_DIR = DEMO_DIR.joinpath("temp")

# Working directory for learned artifacts. It is created at import time; the
# training step later opens files inside it in append mode.
REFINEMENT_TEMP_DIR = BASE_DIR.joinpath("temp")
REFINEMENT_TEMP_DIR.mkdir(parents=True, exist_ok=True)

FEW_SHOT_PATH = REFINEMENT_TEMP_DIR.joinpath("few_shot_examples.txt")
RULES_PATH = REFINEMENT_TEMP_DIR.joinpath("rules.txt")
AUDIODESCRIPTIONS_DB_PATH = DEMO_TEMP_DIR.joinpath("audiodescriptions.db")
def _get_llm() -> Optional[ChatOpenAI]:
    """Return a GPT-4o-mini chat model, or ``None`` if one cannot be created.

    ``None`` is returned (and logged) when ``OPENAI_API_KEY`` is unset or blank,
    or when the ``ChatOpenAI`` constructor raises. Callers in this module then
    skip the introspection step.
    """
    # Strip so a key pasted with a trailing newline/space still works, and a
    # whitespace-only value counts as "not configured".
    api_key = os.environ.get("OPENAI_API_KEY", "").strip()
    if not api_key:
        logger.warning("OPENAI_API_KEY no está configurada; se omite la introspection.")
        return None
    try:
        # temperature=0.0: keep analysis/refinement output as stable as possible.
        return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=api_key)
    except Exception as exc:  # pragma: no cover - external client errors
        logger.error("No se pudo inicializar ChatOpenAI para introspection: %s", exc)
        return None
# --- Reading training data ---
def _iter_une_vs_hitl_pairs(db_path: Optional[Path] = None) -> Iterable[Tuple[str, str, str]]:
    """Yield ``(sha1sum, une_ad_auto, une_ad_hitl)`` triples from the audiodescriptions DB.

    - ``une_ad_auto``: automatic version (rows whose ``version`` is ``'MoE'`` or
      ``'Salamandra'``), column ``une_ad``.
    - ``une_ad_hitl``: HITL-corrected version stored in the same row, column
      ``ok_une_ad``.

    Both texts are whitespace-stripped. A row is skipped when either text is
    empty or when both are identical, because it carries no correction.

    Args:
        db_path: SQLite file to read. Defaults to ``AUDIODESCRIPTIONS_DB_PATH``.

    Yields nothing, after logging a warning, when the file does not exist, is not
    a SQLite database, or lacks the ``audiodescriptions`` table or its columns.
    """
    path = AUDIODESCRIPTIONS_DB_PATH if db_path is None else Path(db_path)
    if not path.exists():
        logger.warning("audiodescriptions.db no encontrado en %s", path)
        return

    conn = sqlite3.connect(str(path))
    conn.row_factory = sqlite3.Row
    try:
        try:
            rows = conn.execute(
                """
                SELECT sha1sum, version, une_ad, ok_une_ad
                FROM audiodescriptions
                WHERE version IN ('MoE', 'Salamandra')
                """
            ).fetchall()
        except sqlite3.DatabaseError as exc:
            # OperationalError (missing table OR missing column) is a subclass;
            # a file that is not a SQLite DB raises DatabaseError here too.
            logger.warning("Tabla audiodescriptions no disponible en %s: %s", path, exc)
            return
    finally:
        # All rows are already in memory: close before yielding so a consumer
        # that stops early (e.g. via ``break``) does not hold the connection.
        conn.close()

    for row in rows:
        une_auto = (row["une_ad"] or "").strip()
        une_hitl = (row["ok_une_ad"] or "").strip()
        if not une_auto or not une_hitl or une_auto == une_hitl:
            # Missing side or no difference: nothing to learn from this row.
            continue
        yield row["sha1sum"], une_auto, une_hitl
def _strip_markdown_fences(content: str) -> str:
    """Remove a Markdown code fence (```...```) wrapping an LLM reply, if any.

    The result is always stripped of surrounding whitespace. Text that does not
    start with a fence is otherwise returned unchanged. Handled shapes::

        ```json          ```json            ```{"a": 1}```
        {"a": 1}         {"a": 1}```        ```json {"a": 1}```
        ```

    Args:
        content: Raw model output.

    Returns:
        The text inside the fence, or the stripped input when no fence opens it.
    """
    text = content.strip()
    if not text.startswith("```"):
        return text

    lines = text.splitlines()
    if len(lines) == 1:
        # Whole reply on one line; previously this dropped the payload entirely.
        inner = text[3:]
        if inner.endswith("```"):
            inner = inner[:-3]
        inner = inner.strip()
        tag, sep, rest = inner.partition(" ")
        if sep and tag.isalnum():
            # Leading language tag such as "json".
            inner = rest
        return inner.strip()

    # Drop the opening line (``` or ```json).
    lines = lines[1:]
    # Drop closing fence line(s); the model occasionally repeats them.
    while lines and lines[-1].strip().startswith("```"):
        lines.pop()
    # Closing fence glued to the last content line, e.g. '}```'.
    if lines and lines[-1].rstrip().endswith("```"):
        lines[-1] = lines[-1].rstrip()[:-3]
    return "\n".join(lines).strip()
def _analyze_correction_with_llm(llm: ChatOpenAI, une_auto: str, une_hitl: str) -> Tuple[str, str]:
    """Ask the LLM to describe a HITL correction and distil a general rule from it.

    Args:
        llm: Chat model used for the analysis.
        une_auto: Automatic UNE-153010 audio description (MoE/Salamandra).
        une_hitl: Human-corrected (HITL) version of the same description.

    Returns:
        ``(few_shot_example, rule)``:

        - ``("", "")`` if the LLM call raises.
        - ``(raw_reply, "")`` if the reply is not a JSON object (invalid JSON, or
          valid JSON of another type such as a list or a bare string).
        - Otherwise both keys converted to stripped text. Objects and lists are
          pretty-printed as JSON; a missing key or JSON ``null`` yields ``""``
          rather than the text ``"None"``.
    """

    def _as_text(value: object) -> str:
        # Render structured values as readable JSON; null/missing becomes "".
        if value is None:
            return ""
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False, indent=2)
        return str(value)

    system = SystemMessage(
        content=(
            "Ets un assistent que analitza correccions d'audiodescripcions UNE-153010. "
            "Se't dona una versió automàtica i una versió corregida per humans (HITL). "
            "La teva tasca és (1) descriure de forma concisa què s'ha corregit, amb "
            "exemples concrets, i (2) proposar una regla general aplicable a futurs SRT. "
            "Respon en format JSON amb les claus 'few_shot_example' i 'rule'."
        )
    )
    user_content = {
        "une_ad_auto": une_auto,
        "une_ad_hitl": une_hitl,
    }
    msg = HumanMessage(content=json.dumps(user_content, ensure_ascii=False))
    try:
        resp = llm.invoke([system, msg])
    except Exception as exc:  # pragma: no cover - external errors
        logger.error("Error llamando al LLM en introspection training: %s", exc)
        return "", ""

    raw = resp.content if isinstance(resp.content, str) else str(resp.content)
    try:
        data = json.loads(_strip_markdown_fences(raw))
    except json.JSONDecodeError:
        logger.warning("La respuesta del LLM no es JSON válido: %s", raw[:2000])
        return raw.strip(), ""
    if not isinstance(data, dict):
        # Valid JSON but without keys to read (e.g. a list or a string).
        logger.warning("La respuesta del LLM no es un objeto JSON: %s", raw[:2000])
        return raw.strip(), ""

    few_shot = _as_text(data.get("few_shot_example")).strip()
    rule = _as_text(data.get("rule")).strip()
    return few_shot, rule
def train_introspection_rules(max_examples: Optional[int] = None) -> None:
    """Learn introspection rules from HITL corrections.

    - Walks ``audiodescriptions.db`` looking for (MoE/Salamandra, HITL) pairs
      whose texts differ.
    - For each pair, asks the LLM for:
        * a "few_shot_example" describing the correction, and
        * a generalised "rule".
    - Appends the examples to ``few_shot_examples.txt`` and the rules not seen
      before to ``rules.txt``, both under ``engine/refinement/temp``.

    ``rules.txt`` holds one rule per line. Multi-line rules returned by the LLM
    are collapsed to a single line so deduplication across runs keeps working.
    A few-shot block identical to one already stored (same sha1sum and text) is
    not appended again, so re-running training does not keep growing the file,
    whose full content is later sent in the refinement prompt.

    Args:
        max_examples: Maximum number of pairs sent to the LLM; ``None`` = no limit.
    """
    llm = _get_llm()
    if llm is None:
        logger.info("Introspection training skipped: no LLM available.")
        return

    logger.info("Començant entrenament d'introspection a partir de %s", AUDIODESCRIPTIONS_DB_PATH)

    # Rules already on disk (one per line), so they are not duplicated.
    seen_rules = {
        line.strip()
        for line in _load_text_file(RULES_PATH).splitlines()
        if line.strip()
    }
    # Few-shot blocks already on disk, plus those written during this run.
    existing_examples = _load_text_file(FEW_SHOT_PATH)
    written_examples: set[str] = set()

    n_processed = 0
    n_generated = 0
    with FEW_SHOT_PATH.open("a", encoding="utf-8") as f_examples, RULES_PATH.open(
        "a", encoding="utf-8"
    ) as f_rules:
        for sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
            if max_examples is not None and n_processed >= max_examples:
                break
            n_processed += 1
            logger.info("Analitzant correcció HITL per sha1sum=%s", sha1sum)
            few_shot, rule = _analyze_correction_with_llm(llm, une_auto, une_hitl)
            # rules.txt is line-oriented (re-read line by line above and presented
            # "one per line" to the refiner): collapse newlines/extra whitespace.
            rule = " ".join(rule.split())
            if not few_shot and not rule:
                continue
            if few_shot:
                example_block = f"# sha1sum={sha1sum}\n{few_shot}\n\n"
                if example_block not in written_examples and example_block not in existing_examples:
                    written_examples.add(example_block)
                    f_examples.write(example_block)
            if rule and rule not in seen_rules:
                seen_rules.add(rule)
                f_rules.write(rule + "\n")
            n_generated += 1

    logger.info(
        "Introspection training completat: %d parelles processades, %d entrades generades",
        n_processed,
        n_generated,
    )
def _load_text_file(path: Path) -> str:
    """Return the UTF-8 text of *path*, or ``""`` if it cannot be read.

    A missing file is the normal "nothing learned yet" case and is silent. Any
    other I/O or decoding failure also yields ``""`` but is logged, so a broken
    rules/examples file does not go unnoticed.
    """
    try:
        return path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return ""
    except (OSError, ValueError) as exc:
        # ValueError covers UnicodeDecodeError on a non-UTF-8 file.
        logger.warning("No se pudo leer %s: %s", path, exc)
        return ""
def refine_srt_with_introspection(srt_content: str) -> str:
    """Apply the introspection step to an SRT.

    - Reads ``few_shot_examples.txt`` and ``rules.txt`` from ``engine/refinement/temp``.
    - Asks GPT-4o-mini to correct the SRT, taking those examples and rules into
      account.
    - A Markdown code fence wrapped around the model's answer (for example
      ```` ```srt ... ``` ````) is removed before returning.

    Args:
        srt_content: SRT text to refine.

    Returns:
        The refined SRT. ``srt_content`` is returned unchanged when it is blank,
        when no LLM is available, when both files are missing or blank, when the
        LLM call fails, or when the model returns an empty answer.
    """
    if not srt_content or not srt_content.strip():
        # Nothing to refine; avoid an LLM call that could only invent content.
        return srt_content

    llm = _get_llm()
    if llm is None:
        return srt_content

    few_shots = _load_text_file(FEW_SHOT_PATH).strip()
    rules = _load_text_file(RULES_PATH).strip()
    if not few_shots and not rules:
        # Nothing learned yet: leave the SRT untouched.
        return srt_content

    system_parts: List[str] = [
        "Ets un assistent que millora audiodescripcions en format SRT.",
        "Tens unes regles d'introspecció derivades de correccions humanes (HITL)",
        "i alguns exemples de correccions anteriors (few-shot examples).",
        "Has de produir un nou SRT que apliqui aquestes regles i millores,",
        "mantenint l'estructura de temps i el format SRT.",
        "Retorna únicament el SRT corregit, sense explicacions addicionals.",
    ]
    if rules:
        system_parts.append("\nRegles d'introspecció (una per línia):\n" + rules)
    if few_shots:
        system_parts.append("\nExemples de correccions (few-shot examples):\n" + few_shots)
    system_msg = SystemMessage(content="\n".join(system_parts))
    user_msg = HumanMessage(
        content=(
            "A continuació tens un SRT generat automàticament. "
            "Aplica les regles i l'estil observat als exemples per millorar-lo, "
            "especialment en aquells aspectes que solen ser corregits pels humans.\n\n"
            "SRT original:\n" + srt_content
        )
    )
    try:
        resp = llm.invoke([system_msg, user_msg])
    except Exception as exc:  # pragma: no cover - external errors
        logger.error("Error llamando al LLM en introspection apply: %s", exc)
        return srt_content

    text = resp.content if isinstance(resp.content, str) else str(resp.content)
    # Models often fence the answer despite the instruction; unfenced text is
    # only stripped.
    refined = _strip_markdown_fences(text)
    return refined or srt_content