"""Multi-agent reflection pipeline for UNE-153010 audio description SRT files.

Four agents (narrator, identity manager, background descriptor, critic) are wired
into a linear LangGraph and applied to an initial SRT.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import TypedDict, Dict, Union, List

from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage

from .reflection import (
    DEMO_TEMP_DIR,
    DEMO_DATA_DIR,
    TEMP_DIR,
    _load_audiodescription_from_db,
    _write_casting_csv_from_db,
    _write_scenarios_csv_from_db,
)

logger = logging.getLogger(__name__)


class MultiReflectionState(TypedDict):
    """State shared by all agents in the multi-agent reflection pipeline."""

    iteration: int
    current_srt_path: str
    critic_report: Dict[str, Union[float, str]]
    # The agents append AIMessage entries, so the history is typed with the
    # common BaseMessage type rather than SystemMessage.
    history: List[BaseMessage]


# Shared LLM instance used by every agent in this pipeline.
_llm_ma = ChatOpenAI(model="gpt-4o-mini", temperature=0.2)


def _read_text(path: Path) -> str:
    try:
        return path.read_text(encoding="utf-8")
    except Exception:
        return ""


def _load_casting_for_sha1(sha1sum: str) -> str:
    """Return the casting rows for a video as a JSON string, or "" if unavailable."""
    db_path = DEMO_DATA_DIR / "casting.db"
    if not db_path.exists():
        return ""
    import sqlite3

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()
        cur.execute("SELECT name, description FROM casting WHERE sha1sum=?", (sha1sum,))
        rows = cur.fetchall()
        if not rows:
            return ""
        data = [dict(r) for r in rows]
        return json.dumps(data, ensure_ascii=False, indent=2)
    finally:
        conn.close()


def _load_scenarios_for_sha1(sha1sum: str) -> str:
    """Return the scenario rows for a video as a JSON string, or "" if unavailable."""
    db_path = DEMO_DATA_DIR / "scenarios.db"
    if not db_path.exists():
        return ""
    import sqlite3

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()
        cur.execute("SELECT name, description FROM scenarios WHERE sha1sum=?", (sha1sum,))
        rows = cur.fetchall()
        if not rows:
            return ""
        data = [dict(r) for r in rows]
        return json.dumps(data, ensure_ascii=False, indent=2)
    finally:
        conn.close()


def narrator_initial(state: MultiReflectionState) -> MultiReflectionState:
    """First narrator pass: take the initial SRT as-is.

    This pipeline assumes the input is already an initial UNE SRT.
    """
    current_path = Path(state["current_srt_path"])
    if not current_path.exists():
        logger.warning("[reflection_ma] SRT inicial no trobat a %s", current_path)
        content = ""
    else:
        content = _read_text(current_path)

    history = state["history"] + [AIMessage(content="Narrador inicial: SRT de partida carregat.")]
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(current_path),
        "critic_report": state.get("critic_report", {}),
        "history": history,
    }


def identity_manager_agent(state: MultiReflectionState, *, sha1sum: str, info_ad: str) -> MultiReflectionState:
    """Agent that reviews identities/characters against the casting data and info_ad."""
    srt_path = Path(state["current_srt_path"])
    srt_content = _read_text(srt_path)
    casting_json = _load_casting_for_sha1(sha1sum)

    prompt = (
        "Ets un gestor d'identitats per audiodescripcions. Se't proporciona un SRT "
        "i informació de casting (personatges) i un JSON de context (info_ad). "
        "La teva tasca és revisar si els noms i rols dels personatges al SRT són "
        "coherents amb el casting i el context. Si cal, corregeix els noms/rols "
        "perquè siguin consistents. Mantén el format SRT i retorna únicament el SRT modificat."
    )

    content = {
        "srt": srt_content,
        "casting": json.loads(casting_json) if casting_json else [],
        "info_ad": json.loads(info_ad) if info_ad else {},
    }

    resp = _llm_ma.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=json.dumps(content, ensure_ascii=False)),
        ]
    )

    new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
    new_path = TEMP_DIR / "une_ad_ma_identity.srt"
    new_path.write_text(new_srt, encoding="utf-8")

    history = state["history"] + [AIMessage(content="Identity manager: SRT actualitzat amb identitats coherents.")]
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(new_path),
        "critic_report": state.get("critic_report", {}),
        "history": history,
    }


def background_descriptor_agent(state: MultiReflectionState, *, sha1sum: str) -> MultiReflectionState:
    """Agent that reviews the scenario descriptions based on scenarios.db."""
    srt_path = Path(state["current_srt_path"])
    srt_content = _read_text(srt_path)
    scenarios_json = _load_scenarios_for_sha1(sha1sum)

    prompt = (
        "Ets un expert en escenaris per audiodescripcions. Se't proporciona un SRT "
        "i una llista d'escenaris amb noms oficials. La teva tasca és revisar les "
        "descripcions de llocs al SRT i substituir referències genèriques per aquests "
        "noms quan millorin la claredat, sense afegir informació inventada. Mantén el "
        "format SRT i retorna únicament el SRT actualitzat."
    )

    content = {
        "srt": srt_content,
        "scenarios": json.loads(scenarios_json) if scenarios_json else [],
    }

    resp = _llm_ma.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=json.dumps(content, ensure_ascii=False)),
        ]
    )

    new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
    new_path = TEMP_DIR / "une_ad_ma_background.srt"
    new_path.write_text(new_srt, encoding="utf-8")

    history = state["history"] + [AIMessage(content="Background descriptor: SRT actualitzat amb escenaris contextualitzats.")]
    return {
        "iteration": state["iteration"],
        "current_srt_path": str(new_path),
        "critic_report": state.get("critic_report", {}),
        "history": history,
    }


def narrator_refine_agent(state: MultiReflectionState, *, info_ad: str) -> MultiReflectionState:
    """Second narrator pass: rewrite the SRT taking identities and scenarios into account."""
    srt_path = Path(state["current_srt_path"])
    srt_content = _read_text(srt_path)

    prompt = (
        "Ets un Narrador d'audiodescripció UNE-153010. Has rebut un SRT on ja s'han "
        "revisat les identitats dels personatges i els escenaris. La teva tasca és "
        "refinar el text d'audiodescripció perquè sigui clar, coherent i ajustat al "
        "temps disponible, mantenint el format SRT i sense alterar els diàlegs. "
        "Retorna únicament el SRT final."
    )

    content = {
        "srt": srt_content,
        "info_ad": json.loads(info_ad) if info_ad else {},
    }

    resp = _llm_ma.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=json.dumps(content, ensure_ascii=False)),
        ]
    )

    new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
    new_path = TEMP_DIR / "une_ad_ma_final.srt"
    new_path.write_text(new_srt, encoding="utf-8")

    history = state["history"] + [AIMessage(content="Narrador: SRT refinat després de gestió d'identitats i escenaris.")]
    return {
        "iteration": state["iteration"] + 1,
        "current_srt_path": str(new_path),
        "critic_report": state.get("critic_report", {}),
        "history": history,
    }


def critic_agent(state: MultiReflectionState) -> MultiReflectionState:
    """Agent that gives a qualitative evaluation of the final SRT.

    For simplicity, no CSV files or weighted averages are produced here; only a summary.
    """
    srt_path = Path(state["current_srt_path"])
    srt_content = _read_text(srt_path)

    prompt = (
        "Ets un crític d'audiodescripcions UNE-153010. Avalua breument la qualitat "
        "del SRT proporcionat en termes de precisió descriptiva, sincronització "
        "temporal, claredat i adequació dels noms de personatges i escenaris. "
        "Retorna un text breu en català amb la teva valoració general."
    )

    resp = _llm_ma.invoke(
        [
            SystemMessage(content=prompt),
            HumanMessage(content=srt_content),
        ]
    )

    critique = resp.content if isinstance(resp.content, str) else str(resp.content)
    report: Dict[str, Union[float, str]] = {
        "qualitative_critique": critique,
    }

    history = state["history"] + [AIMessage(content="Crític: valoració final generada.")]
    return {
        "iteration": state["iteration"],
        "current_srt_path": state["current_srt_path"],
        "critic_report": report,
        "history": history,
    }


_graph = StateGraph(MultiReflectionState)
_graph.add_node("narrator_initial", narrator_initial)
_graph.add_node("identity_manager", lambda s: identity_manager_agent(s, sha1sum=_graph.sha1sum, info_ad=_graph.info_ad))
_graph.add_node("background_descriptor", lambda s: background_descriptor_agent(s, sha1sum=_graph.sha1sum))
_graph.add_node("narrator_refine", lambda s: narrator_refine_agent(s, info_ad=_graph.info_ad))
_graph.add_node("critic", critic_agent)

_graph.set_entry_point("narrator_initial")
_graph.add_edge("narrator_initial", "identity_manager")
_graph.add_edge("identity_manager", "background_descriptor")
_graph.add_edge("background_descriptor", "narrator_refine")
_graph.add_edge("narrator_refine", "critic")
_graph.add_edge("critic", END)


def _compile_app(sha1sum: str, info_ad: str):
    """Compile a LangGraph app instance bound to the given video parameters."""
    # The per-video parameters are stashed as attributes on the module-level graph;
    # the node lambdas read them at invoke time.
    _graph.sha1sum = sha1sum
    _graph.info_ad = info_ad
    return _graph.compile()
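

# Sharing one module-level graph and mutating its attributes works for sequential
# calls, but it is shared mutable state. A possible alternative (a sketch, not used
# here) would bind the parameters through closures on a per-call graph:
#
# def _build_app(sha1sum: str, info_ad: str):
#     g = StateGraph(MultiReflectionState)
#     g.add_node("narrator_initial", narrator_initial)
#     g.add_node("identity_manager", lambda s: identity_manager_agent(s, sha1sum=sha1sum, info_ad=info_ad))
#     g.add_node("background_descriptor", lambda s: background_descriptor_agent(s, sha1sum=sha1sum))
#     g.add_node("narrator_refine", lambda s: narrator_refine_agent(s, info_ad=info_ad))
#     g.add_node("critic", critic_agent)
#     g.set_entry_point("narrator_initial")
#     g.add_edge("narrator_initial", "identity_manager")
#     g.add_edge("identity_manager", "background_descriptor")
#     g.add_edge("background_descriptor", "narrator_refine")
#     g.add_edge("narrator_refine", "critic")
#     g.add_edge("critic", END)
#     return g.compile()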


def refine_video_with_reflection_ma(sha1sum: str, version: str) -> str:
    """Refine a video (sha1sum, version) with the 4-agent multi-agent pipeline.

    - Reads une_ad and info_ad from audiodescriptions.db (demo/temp).
    - Reads casting/scenarios for the same sha1sum.
    - Runs the pipeline narrator -> identity_manager -> background_descriptor -> narrator -> critic.
    - Returns the final generated SRT.
    """
    une_ad, info_ad = _load_audiodescription_from_db(sha1sum, version)

    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    initial_path = TEMP_DIR / "une_ad_ma_0.srt"
    initial_path.write_text(une_ad or "", encoding="utf-8")

    app = _compile_app(sha1sum, info_ad or "")
    initial_state: MultiReflectionState = {
        "iteration": 0,
        "current_srt_path": str(initial_path),
        "critic_report": {},
        "history": [],
    }

    final_state = app.invoke(initial_state)
    final_path = Path(final_state["current_srt_path"])
    return _read_text(final_path)


def refine_srt_with_reflection_ma(srt_content: str) -> str:
    """Simplified variant that only receives an SRT (no database info).

    It just runs the SRT through the identity/scenario pipeline without looking at
    casting/scenarios/info_ad. Useful for unit tests.
    """
    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    initial_path = TEMP_DIR / "une_ad_ma_0.srt"
    initial_path.write_text(srt_content or "", encoding="utf-8")

    app = _compile_app(sha1sum="", info_ad="{}")
    initial_state: MultiReflectionState = {
        "iteration": 0,
        "current_srt_path": str(initial_path),
        "critic_report": {},
        "history": [],
    }

    final_state = app.invoke(initial_state)
    final_path = Path(final_state["current_srt_path"])
    return _read_text(final_path)
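

# Minimal usage sketch (illustrative only): run the simplified entry point on a tiny
# SRT snippet. Assumes OPENAI_API_KEY is configured in the environment, since every
# agent calls the shared ChatOpenAI model; the refined output depends on the LLM.
if __name__ == "__main__":
    sample_srt = (
        "1\n"
        "00:00:01,000 --> 00:00:04,000\n"
        "Una dona entra en una sala fosca.\n"
    )
    print(refine_srt_with_reflection_ma(sample_srt))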