"""M貌dul per a l'agent d'"introspection".
|
|
|
|
|
|
Implementa:
|
|
|
|
|
|
- Un proc茅s d'entrenament que apr猫n de les correccions HITL comparant
|
|
|
`une_ad` autom脿tic (MoE/Salamandra) amb `une_ad` de la versi贸 HITL.
|
|
|
- Un pas d'introspecci贸 que aplica aquestes regles a un nou SRT utilitzant
|
|
|
GPT-4o-mini.
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|

import json
import logging
import os
import sqlite3
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).resolve().parent
REPO_ROOT = BASE_DIR.parents[1]
DEMO_DIR = REPO_ROOT / "demo"
DEMO_TEMP_DIR = DEMO_DIR / "temp"

REFINEMENT_TEMP_DIR = BASE_DIR / "temp"
REFINEMENT_TEMP_DIR.mkdir(exist_ok=True, parents=True)

FEW_SHOT_PATH = REFINEMENT_TEMP_DIR / "few_shot_examples.txt"
RULES_PATH = REFINEMENT_TEMP_DIR / "rules.txt"

AUDIODESCRIPTIONS_DB_PATH = DEMO_TEMP_DIR / "audiodescriptions.db"
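
# For reference, the training pass below only assumes that ``audiodescriptions.db``
# exposes a table with at least these columns (a minimal sketch derived from the
# SELECT in _iter_une_vs_hitl_pairs; the real table may define more columns):
#
#   CREATE TABLE IF NOT EXISTS audiodescriptions (
#       sha1sum   TEXT,   -- identifier of the source clip / SRT
#       version   TEXT,   -- 'MoE' or 'Salamandra' for automatic outputs
#       une_ad    TEXT,   -- automatic UNE-153010 audio description
#       ok_une_ad TEXT    -- HITL-corrected version, when available
#   );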


def _get_llm() -> Optional[ChatOpenAI]:
    """Return a GPT-4o-mini instance, or None when no API key is configured."""
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY is not set; skipping introspection.")
        return None
    try:
        return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=api_key)
    except Exception as exc:
        logger.error("Could not initialize ChatOpenAI for introspection: %s", exc)
        return None


def _iter_une_vs_hitl_pairs() -> Iterable[Tuple[str, str, str]]:
    """Iterate over (sha1sum, une_ad_auto, une_ad_hitl) tuples.

    - une_ad_auto: the automatic version (MoE or Salamandra), column ``une_ad``.
    - une_ad_hitl: the HITL-corrected version stored on the same row, column ``ok_une_ad``.
    """
    if not AUDIODESCRIPTIONS_DB_PATH.exists():
        logger.warning("audiodescriptions.db not found at %s", AUDIODESCRIPTIONS_DB_PATH)
        return

    conn = sqlite3.connect(str(AUDIODESCRIPTIONS_DB_PATH))
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()
        try:
            cur.execute(
                """
                SELECT sha1sum, version, une_ad, ok_une_ad
                FROM audiodescriptions
                WHERE version IN ('MoE', 'Salamandra')
                """
            )
        except sqlite3.OperationalError:
            logger.warning("Table audiodescriptions not available in %s", AUDIODESCRIPTIONS_DB_PATH)
            return

        rows = cur.fetchall()
        for row in rows:
            sha1sum = row["sha1sum"]
            une_auto = (row["une_ad"] or "").strip()
            une_hitl = (row["ok_une_ad"] or "").strip() if "ok_une_ad" in row.keys() else ""

            # Skip rows missing either version, or where the HITL text matches the
            # automatic one (i.e. there is no correction to learn from).
            if not une_auto or not une_hitl:
                continue
            if une_hitl == une_auto:
                continue

            yield sha1sum, une_auto, une_hitl
    finally:
        conn.close()


def _strip_markdown_fences(content: str) -> str:
    """Strip surrounding ```...``` fences from a JSON response, if present."""
    text = content.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        # Drop the opening fence line (which may carry a language tag, e.g. ```json).
        lines = lines[1:]
        # Drop any trailing fence lines.
        while lines and lines[-1].strip().startswith("```"):
            lines.pop()
        text = "\n".join(lines).strip()
    return text


def _analyze_correction_with_llm(llm: ChatOpenAI, une_auto: str, une_hitl: str) -> Tuple[str, str]:
    """Ask the LLM to describe the correction and extract a general rule.

    Returns (few_shot_example, rule). On failure, returns empty strings.
    """
    system = SystemMessage(
        content=(
            "Ets un assistent que analitza correccions d'audiodescripcions UNE-153010. "
            "Se't dona una versió automàtica i una versió corregida per humans (HITL). "
            "La teva tasca és (1) descriure de forma concisa què s'ha corregit, amb "
            "exemples concrets, i (2) proposar una regla general aplicable a futurs SRT. "
            "Respon en format JSON amb les claus 'few_shot_example' i 'rule'."
        )
    )

    user_content = {
        "une_ad_auto": une_auto,
        "une_ad_hitl": une_hitl,
    }
    msg = HumanMessage(content=json.dumps(user_content, ensure_ascii=False))

    try:
        resp = llm.invoke([system, msg])
    except Exception as exc:
        logger.error("Error calling the LLM during introspection training: %s", exc)
        return "", ""

    raw = resp.content if isinstance(resp.content, str) else str(resp.content)
    text = _strip_markdown_fences(raw)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("LLM response is not valid JSON: %s", raw[:2000])
        return raw.strip(), ""

    few = data.get("few_shot_example", "")
    # The model may return the example as a nested object; serialize it if so.
    if isinstance(few, dict):
        try:
            few_shot = json.dumps(few, ensure_ascii=False, indent=2)
        except Exception:
            few_shot = str(few)
    else:
        few_shot = str(few)

    rule = str(data.get("rule", "")).strip()
    return few_shot.strip(), rule
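
# Illustrative shape of the JSON answer requested above (only the two keys are
# assumed by the parsing code; the wording of the values is up to the model):
#
#   {
#     "few_shot_example": "<concise description of the correction, with concrete examples>",
#     "rule": "<general rule applicable to future SRTs>"
#   }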


def train_introspection_rules(max_examples: Optional[int] = None) -> None:
    """Train introspection rules from the HITL corrections.

    - Walks audiodescriptions.db looking for (MoE/Salamandra, HITL) pairs.
    - For each pair with meaningful differences, asks the LLM for:
        * A "few_shot_example" describing the correction.
        * A generalized "rule".
    - Appends the examples to ``few_shot_examples.txt`` and the unique rules to
      ``rules.txt`` under ``engine/refinement/temp``.
    """
    llm = _get_llm()
    if llm is None:
        logger.info("Introspection training skipped: no LLM available.")
        return

    logger.info("Starting introspection training from %s", AUDIODESCRIPTIONS_DB_PATH)

    # Load rules already on disk so repeated runs do not duplicate them.
    existing_rules: List[str] = []
    if RULES_PATH.exists():
        try:
            existing_rules = [
                line.strip()
                for line in RULES_PATH.read_text(encoding="utf-8").splitlines()
                if line.strip()
            ]
        except Exception:
            existing_rules = []

    seen_rules = set(existing_rules)

    n_processed = 0
    n_generated = 0

    with FEW_SHOT_PATH.open("a", encoding="utf-8") as f_examples, RULES_PATH.open(
        "a", encoding="utf-8"
    ) as f_rules:
        for sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
            if max_examples is not None and n_processed >= max_examples:
                break

            n_processed += 1
            logger.info("Analyzing HITL correction for sha1sum=%s", sha1sum)

            few_shot, rule = _analyze_correction_with_llm(llm, une_auto, une_hitl)
            if not few_shot and not rule:
                continue

            if few_shot:
                f_examples.write("# sha1sum=" + sha1sum + "\n")
                f_examples.write(few_shot + "\n\n")

            if rule and rule not in seen_rules:
                seen_rules.add(rule)
                f_rules.write(rule + "\n")

            n_generated += 1

    logger.info(
        "Introspection training complete: %d pairs processed, %d entries generated",
        n_processed,
        n_generated,
    )
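
# Each entry appended to few_shot_examples.txt above looks roughly like:
#
#   # sha1sum=<hash>
#   <few_shot_example text, possibly multi-line>
#   <blank line>
#
# while rules.txt accumulates one generalized rule per line.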


def _load_text_file(path: Path) -> str:
    """Return the text content of ``path``, or an empty string if unavailable."""
    if not path.exists():
        return ""
    try:
        return path.read_text(encoding="utf-8")
    except Exception:
        return ""


def refine_srt_with_introspection(srt_content: str) -> str:
    """Apply the introspection step to an SRT.

    - Reads ``few_shot_examples.txt`` and ``rules.txt`` from ``engine/refinement/temp``.
    - Asks GPT-4o-mini to correct the SRT taking those examples and rules into account.
    - If no LLM or no learned files are available, returns the original SRT.
    """
    llm = _get_llm()
    if llm is None:
        return srt_content

    few_shots = _load_text_file(FEW_SHOT_PATH)
    rules = _load_text_file(RULES_PATH)

    # Nothing learned yet: return the SRT untouched.
    if not few_shots and not rules:
        return srt_content

    system_parts: List[str] = [
        "Ets un assistent que millora audiodescripcions en format SRT.",
        "Tens unes regles d'introspecció derivades de correccions humanes (HITL)",
        "i alguns exemples de correccions anteriors (few-shot examples).",
        "Has de produir un nou SRT que apliqui aquestes regles i millores,",
        "mantenint l'estructura de temps i el format SRT.",
        "Retorna únicament el SRT corregit, sense explicacions addicionals.",
    ]

    if rules:
        system_parts.append("\nRegles d'introspecció (una per línia):\n" + rules)

    if few_shots:
        system_parts.append("\nExemples de correccions (few-shot examples):\n" + few_shots)

    system_msg = SystemMessage(content="\n".join(system_parts))

    user_msg = HumanMessage(
        content=(
            "A continuació tens un SRT generat automàticament. "
            "Aplica les regles i l'estil observat als exemples per millorar-lo, "
            "especialment en aquells aspectes que solen ser corregits pels humans.\n\n"
            "SRT original:\n" + srt_content
        )
    )

    try:
        resp = llm.invoke([system_msg, user_msg])
    except Exception as exc:
        logger.error("Error calling the LLM during introspection apply: %s", exc)
        return srt_content

    text = resp.content if isinstance(resp.content, str) else str(resp.content)
    return text.strip() or srt_content
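

# Minimal usage sketch (illustrative only): learn rules from existing HITL
# corrections, then refine a freshly generated SRT. It assumes OPENAI_API_KEY is
# set and that demo/temp/audiodescriptions.db already holds corrected entries.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Learn few-shot examples and rules from at most 5 HITL corrections.
    train_introspection_rules(max_examples=5)

    # Apply the learned rules to a toy SRT string.
    sample_srt = (
        "1\n"
        "00:00:01,000 --> 00:00:04,000\n"
        "Una dona entra a l'habitació.\n"
    )
    print(refine_srt_with_introspection(sample_srt))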