# engine/refinement/reflection_ma.py
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TypedDict, Dict, Union, List
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from .reflection import (
DEMO_TEMP_DIR,
DEMO_DATA_DIR,
TEMP_DIR,
_load_audiodescription_from_db,
_write_casting_csv_from_db,
_write_scenarios_csv_from_db,
)
logger = logging.getLogger(__name__)
class MultiReflectionState(TypedDict):
    """Shared state threaded through every node of the multi-agent pipeline."""

    # Number of completed narrator refinement passes (incremented by narrator_refine_agent).
    iteration: int
    # Filesystem path of the most recent SRT produced by an agent.
    current_srt_path: str
    # Critic's final evaluation (empty until critic_agent runs).
    critic_report: Dict[str, Union[float, str]]
    # Message log; each agent appends an AIMessage describing what it did.
    history: List[Union[SystemMessage, AIMessage]]
# Dedicated LLM for the multi-agent pipeline (cheaper model).
_llm_ma = ChatOpenAI(model="gpt-4o-mini", temperature=0.2)
def _read_text(path: Path) -> str:
try:
return path.read_text(encoding="utf-8")
except Exception:
return ""
def _load_casting_for_sha1(sha1sum: str) -> str:
    """Return the casting rows for *sha1sum* as a JSON string, or "" if none.

    Reads DEMO_DATA_DIR/casting.db; a missing database file also yields "".
    """
    database = DEMO_DATA_DIR / "casting.db"
    if not database.exists():
        return ""
    import sqlite3

    connection = sqlite3.connect(str(database))
    connection.row_factory = sqlite3.Row
    try:
        records = connection.execute(
            "SELECT name, description FROM casting WHERE sha1sum=?", (sha1sum,)
        ).fetchall()
        if not records:
            return ""
        return json.dumps(
            [dict(record) for record in records], ensure_ascii=False, indent=2
        )
    finally:
        connection.close()
def _load_scenarios_for_sha1(sha1sum: str) -> str:
    """Return the scenario rows for *sha1sum* as a JSON string, or "" if none.

    Reads DEMO_DATA_DIR/scenarios.db; a missing database file also yields "".
    """
    scenarios_db = DEMO_DATA_DIR / "scenarios.db"
    if not scenarios_db.exists():
        return ""
    import sqlite3

    con = sqlite3.connect(str(scenarios_db))
    con.row_factory = sqlite3.Row
    try:
        cursor = con.cursor()
        cursor.execute(
            "SELECT name, description FROM scenarios WHERE sha1sum=?",
            (sha1sum,),
        )
        fetched = cursor.fetchall()
        if not fetched:
            return ""
        payload = [dict(entry) for entry in fetched]
        return json.dumps(payload, ensure_ascii=False, indent=2)
    finally:
        con.close()
def narrator_initial(state: MultiReflectionState) -> MultiReflectionState:
    """First narrator step: take the initial SRT as-is.

    This pipeline assumes the input is already an initial UNE SRT; the node
    only checks the file exists (warning if not) and records the step in the
    history. Returns the state unchanged apart from the appended history entry.
    """
    current_path = Path(state["current_srt_path"])
    if not current_path.exists():
        logger.warning("[reflection_ma] SRT inicial no trobat a %s", current_path)
    # NOTE(review): the original also read the file into a local that was never
    # used; that dead read is dropped here — no observable behavior change.
    history = state["history"] + [AIMessage(content="Narrador inicial: SRT de partida carregat.")]
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(current_path),
        "critic_report": state.get("critic_report", {}),
        "history": history,
    }
def identity_manager_agent(state: MultiReflectionState, *, sha1sum: str, info_ad: str) -> MultiReflectionState:
    """Agent that reconciles character names/roles against casting data and info_ad.

    Sends the current SRT plus casting/context JSON to the LLM and writes the
    corrected SRT to a new temp file, which becomes the new current_srt_path.
    """
    current_srt = _read_text(Path(state["current_srt_path"]))
    casting_json = _load_casting_for_sha1(sha1sum)
    system_text = (
        "Ets un gestor d'identitats per audiodescripcions. Se't proporciona un SRT "
        "i informaci贸 de casting (personatges) i un JSON de context (info_ad). "
        "La teva tasca 茅s revisar si els noms i rols dels personatges al SRT s贸n "
        "coherents amb el casting i el context. Si cal, corregeix els noms/rols "
        "perqu猫 siguin consistents. Mant茅n el format SRT i retorna 煤nicament el SRT modificat."
    )
    payload = json.dumps(
        {
            "srt": current_srt,
            "casting": json.loads(casting_json) if casting_json else [],
            "info_ad": json.loads(info_ad) if info_ad else {},
        },
        ensure_ascii=False,
    )
    answer = _llm_ma.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=payload)]
    )
    updated_srt = answer.content if isinstance(answer.content, str) else str(answer.content)
    out_path = TEMP_DIR / "une_ad_ma_identity.srt"
    out_path.write_text(updated_srt, encoding="utf-8")
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(out_path),
        "critic_report": state.get("critic_report", {}),
        "history": state["history"]
        + [AIMessage(content="Identity manager: SRT actualitzat amb identitats coherents.")],
    }
def background_descriptor_agent(state: MultiReflectionState, *, sha1sum: str) -> MultiReflectionState:
    """Agent that revises scene/location descriptions using scenarios.db.

    Sends the current SRT plus the official scenario list to the LLM and
    writes the updated SRT to a new temp file (becomes current_srt_path).
    """
    current_srt = _read_text(Path(state["current_srt_path"]))
    scenarios_json = _load_scenarios_for_sha1(sha1sum)
    system_text = (
        "Ets un expert en escenaris per audiodescripcions. Se't proporciona un SRT "
        "i una llista d'escenaris amb noms oficials. La teva tasca 茅s revisar les "
        "descripcions de llocs al SRT i substituir refer猫ncies gen猫riques per aquests "
        "noms quan millorin la claredat, sense afegir informaci贸 inventada. Mant茅n el "
        "format SRT i retorna 煤nicament el SRT actualitzat."
    )
    payload = json.dumps(
        {
            "srt": current_srt,
            "scenarios": json.loads(scenarios_json) if scenarios_json else [],
        },
        ensure_ascii=False,
    )
    answer = _llm_ma.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=payload)]
    )
    updated_srt = answer.content if isinstance(answer.content, str) else str(answer.content)
    out_path = TEMP_DIR / "une_ad_ma_background.srt"
    out_path.write_text(updated_srt, encoding="utf-8")
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(out_path),
        "critic_report": state.get("critic_report", {}),
        "history": state["history"]
        + [AIMessage(content="Background descriptor: SRT actualitzat amb escenaris contextualitzats.")],
    }
def narrator_refine_agent(state: MultiReflectionState, *, info_ad: str) -> MultiReflectionState:
    """Second narrator step: rewrite the SRT after identity/scenario revisions.

    Writes the refined SRT to a new temp file and bumps the iteration counter
    (this is the only node that increments it).
    """
    current_srt = _read_text(Path(state["current_srt_path"]))
    system_text = (
        "Ets un Narrador d'audiodescripci贸 UNE-153010. Has rebut un SRT on ja s'han "
        "revisat les identitats dels personatges i els escenaris. La teva tasca 茅s "
        "refinar el text d'audiodescripci贸 perqu猫 sigui clar, coherent i ajustat al "
        "temps disponible, mantenint el format SRT i sense alterar els di脿legs. "
        "Retorna 煤nicament el SRT final."
    )
    payload = json.dumps(
        {
            "srt": current_srt,
            "info_ad": json.loads(info_ad) if info_ad else {},
        },
        ensure_ascii=False,
    )
    answer = _llm_ma.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=payload)]
    )
    refined_srt = answer.content if isinstance(answer.content, str) else str(answer.content)
    out_path = TEMP_DIR / "une_ad_ma_final.srt"
    out_path.write_text(refined_srt, encoding="utf-8")
    return {
        "iteration": state["iteration"] + 1,
        "current_srt_path": str(out_path),
        "critic_report": state.get("critic_report", {}),
        "history": state["history"]
        + [AIMessage(content="Narrador: SRT refinat despr茅s de gesti贸 d'identitats i escenaris.")],
    }
def critic_agent(state: MultiReflectionState) -> MultiReflectionState:
    """Agent that produces a qualitative evaluation of the final SRT.

    Unlike other pipelines, no CSV or weighted averages are produced here —
    only a short free-text critique, stored in critic_report. The SRT path is
    passed through unchanged.
    """
    current_srt = _read_text(Path(state["current_srt_path"]))
    system_text = (
        "Ets un cr铆tic d'audiodescripcions UNE-153010. Avalua breument la qualitat "
        "del SRT proporcionat en termes de precisi贸 descriptiva, sincronitzaci贸 "
        "temporal, claredat i adequaci贸 dels noms de personatges i escenaris. "
        "Retorna un text breu en catal脿 amb la teva valoraci贸 general."
    )
    answer = _llm_ma.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=current_srt)]
    )
    critique_text = answer.content if isinstance(answer.content, str) else str(answer.content)
    final_report: Dict[str, Union[float, str]] = {"qualitative_critique": critique_text}
    return {
        "iteration": state["iteration"],
        "current_srt_path": state["current_srt_path"],
        "critic_report": final_report,
        "history": state["history"] + [AIMessage(content="Cr铆tic: valoraci贸 final generada.")],
    }
# Graph construction (module-level singleton).
# NOTE(review): the node lambdas read `_graph.sha1sum` / `_graph.info_ad`,
# attributes that are monkey-patched onto the StateGraph object by
# `_compile_app`.  Invoking a compiled app before `_compile_app` has run
# would raise AttributeError, and concurrent compiles share these attributes.
_graph = StateGraph(MultiReflectionState)
_graph.add_node("narrator_initial", narrator_initial)
_graph.add_node("identity_manager", lambda s: identity_manager_agent(s, sha1sum=_graph.sha1sum, info_ad=_graph.info_ad))
_graph.add_node("background_descriptor", lambda s: background_descriptor_agent(s, sha1sum=_graph.sha1sum))
_graph.add_node("narrator_refine", lambda s: narrator_refine_agent(s, info_ad=_graph.info_ad))
_graph.add_node("critic", critic_agent)
# Linear flow: narrator -> identity -> background -> narrator -> critic -> END.
_graph.set_entry_point("narrator_initial")
_graph.add_edge("narrator_initial", "identity_manager")
_graph.add_edge("identity_manager", "background_descriptor")
_graph.add_edge("background_descriptor", "narrator_refine")
_graph.add_edge("narrator_refine", "critic")
_graph.add_edge("critic", END)
def _compile_app(sha1sum: str, info_ad: str):
    """Compile a LangGraph app instance with per-video parameters.

    The node lambdas in the module-level `_graph` close over the graph object
    itself, so the per-video parameters are stashed as attributes on it.
    NOTE(review): this makes compilation non-reentrant — concurrent calls with
    different videos would overwrite each other's parameters.
    """
    # Stash parameters on the graph object so the node lambdas can read them.
    _graph.sha1sum = sha1sum  # type: ignore[attr-defined]
    _graph.info_ad = info_ad  # type: ignore[attr-defined]
    return _graph.compile()
def refine_video_with_reflection_ma(sha1sum: str, version: str) -> str:
    """Refine a video (sha1sum, version) with the 4-agent multi-agent pipeline.

    - Reads une_ad and info_ad from audiodescriptions.db (demo/temp).
    - Reads casting/scenarios for the same sha1sum.
    - Runs narrator -> identity_manager -> background_descriptor -> narrator -> critic.
    - Returns the final generated SRT text.
    """
    une_ad, info_ad = _load_audiodescription_from_db(sha1sum, version)
    # Seed the pipeline with the initial SRT written to a temp file.
    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    seed_path = TEMP_DIR / "une_ad_ma_0.srt"
    seed_path.write_text(une_ad or "", encoding="utf-8")
    pipeline = _compile_app(sha1sum, info_ad or "")
    result = pipeline.invoke(
        {
            "iteration": 0,
            "current_srt_path": str(seed_path),
            "critic_report": {},
            "history": [],
        }
    )
    return _read_text(Path(result["current_srt_path"]))
def refine_srt_with_reflection_ma(srt_content: str) -> str:
    """Simplified variant that only receives an SRT (no DB information).

    Runs the identity/scenario pipeline without consulting casting, scenarios
    or info_ad (empty sha1sum, empty info_ad JSON).  Useful for unit tests.
    """
    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    seed_path = TEMP_DIR / "une_ad_ma_0.srt"
    seed_path.write_text(srt_content or "", encoding="utf-8")
    # Standalone mode: no sha1sum and no info_ad are available.
    pipeline = _compile_app(sha1sum="", info_ad="{}")
    result = pipeline.invoke(
        {
            "iteration": 0,
            "current_srt_path": str(seed_path),
            "critic_report": {},
            "history": [],
        }
    )
    return _read_text(Path(result["current_srt_path"]))