engine / background_descriptor.py
VeuReu's picture
Upload 17 files
287f01b verified
raw
history blame
5.82 kB
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from vision_tools import (
keyframe_conditional_extraction_ana,
keyframe_every_second,
process_frames,
FaceOfImageEmbedding,
generar_montage,
describe_montage_sequence, # fallback local
)
from llm_router import load_yaml, LLMRouter
def cluster_ocr_sequential(ocr_list: List[Dict[str, Any]], threshold: float = 0.6) -> List[Dict[str, Any]]:
    """Collapse consecutive per-second OCR readings into clusters of similar text.

    Encodes each OCR string with a sentence embedding and starts a new cluster
    whenever the cosine similarity to the previous embedding drops below
    ``threshold``. Each cluster is represented by its *last* member's frame,
    stamped with the cluster's earliest start time.

    Args:
        ocr_list: Frame dicts with at least ``ocr``, ``image_path``, ``start``,
            ``end`` and ``faces`` keys (per-second frames, in temporal order).
        threshold: Cosine-similarity cutoff below which a new cluster begins.

    Returns:
        One representative dict per cluster (``start`` is the cluster start,
        the remaining fields come from the representative frame). Empty list
        when there is no usable OCR text.
    """
    if not ocr_list:
        return []
    # Filter once and keep the *items* — embedding index i must always refer to
    # valid_items[i]. (The previous version filtered the texts but then indexed
    # the unfiltered ocr_list with the filtered index, mis-attributing start
    # times and representative frames whenever some entries lacked OCR text.)
    valid_items = [item for item in ocr_list if item and isinstance(item.get("ocr"), str)]
    if not valid_items:
        return []
    texts = [item["ocr"] for item in valid_items]
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, normalize_embeddings=True)
    clusters_repr = []
    prev_emb = embeddings[0]
    start_time = valid_items[0]["start"]
    for i, emb in enumerate(embeddings[1:], 1):
        sim = cosine_similarity([prev_emb], [emb])[0][0]
        if sim < threshold:
            # Close the current cluster on the previous frame and open a new one.
            clusters_repr.append({"index": i - 1, "start_time": start_time})
            prev_emb = emb
            start_time = valid_items[i]["start"]
    # The final cluster always closes on the last embedded frame.
    clusters_repr.append({"index": len(embeddings) - 1, "start_time": start_time})
    ocr_final = []
    for cluster in clusters_repr:
        it = valid_items[cluster["index"]]
        if it.get("ocr"):  # skip representatives whose OCR text is empty
            ocr_final.append({
                "ocr": it.get("ocr"),
                "image_path": it.get("image_path"),
                "start": cluster["start_time"],
                "end": it.get("end"),
                "faces": it.get("faces"),
            })
    return ocr_final
def build_keyframes_and_per_second(
    video_path: str,
    out_dir: Path,
    cfg: Dict[str, Any],
    face_collection=None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], float]:
    """Extract keyframes and per-second frames, then merge clustered OCR into them.

    Runs conditional keyframe extraction and per-second sampling, processes both
    frame sets (OCR + face embedding via Facenet512), clusters the per-second
    OCR text, and replaces each keyframe's span with the OCR-cluster frames it
    contains (falling back to the keyframe itself when no cluster falls inside).

    Args:
        video_path: Path to the input video.
        out_dir: Output directory; ``keyframes/`` and ``frames_per_second/``
            subdirectories are created under it.
        cfg: Processing config; ``video_processing.ocr_clustering.similarity_threshold``
            tunes the OCR clustering (default 0.6).
        face_collection: Optional face collection forwarded to ``process_frames``.

    Returns:
        Tuple of (merged keyframe dicts with sequential ``id``s, processed
        per-second frames, 0.0 placeholder).
    """
    keyframes_raw = keyframe_conditional_extraction_ana(
        video_path=video_path, output_dir=str(out_dir / "keyframes")
    )
    per_second_raw = keyframe_every_second(
        video_path=video_path, output_dir=str(out_dir / "frames_per_second")
    )
    embedder = FaceOfImageEmbedding(deepface_model="Facenet512")
    kf_proc = process_frames(frames=keyframes_raw, config=cfg, face_col=face_collection, embedding_model=embedder)
    ps_proc = process_frames(frames=per_second_raw, config=cfg, face_col=face_collection, embedding_model=embedder)

    ocr_entries = [
        {
            "ocr": frame.get("ocr"),
            "image_path": frame.get("image_path"),
            "start": frame.get("start"),
            "end": frame.get("end"),
            "faces": frame.get("faces"),
        }
        for frame in ps_proc
    ]
    clustering_cfg = cfg.get("video_processing", {}).get("ocr_clustering", {})
    ocr_final = cluster_ocr_sequential(
        ocr_entries, threshold=float(clustering_cfg.get("similarity_threshold", 0.6))
    )

    merged: List[Dict[str, Any]] = []
    next_id = 1
    for kf in kf_proc:
        kf_start, kf_end = kf["start"], kf["end"]
        # OCR clusters whose span is fully contained in this keyframe's span.
        contained = [f for f in ocr_final if f["start"] >= kf_start and f["end"] <= kf_end]
        for pos, frame in enumerate(contained):
            merged.append({
                "id": next_id,
                # The first contained cluster inherits the keyframe's start;
                # later ones keep their own cluster start.
                "start": kf["start"] if pos == 0 else frame["start"],
                "end": None,
                "image_path": frame["image_path"],
                "faces": frame["faces"],
                "ocr": frame.get("ocr"),
                "description": None,
            })
            next_id += 1
        if not contained:
            # No OCR cluster inside this keyframe: keep the keyframe itself.
            fallback = dict(kf)
            fallback["id"] = next_id
            merged.append(fallback)
            next_id += 1
    return merged, ps_proc, 0.0
def describe_keyframes_with_llm(
    keyframes: List[Dict[str, Any]],
    out_dir: Path,
    face_identities: Optional[set] = None,
    config_path: str | None = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """Fill each keyframe's ``description`` using a vision LLM, with a local fallback.

    Builds a montage of the keyframe images, asks the configured vision model
    (``background_descriptor.description.model``) through ``LLMRouter``, and on
    any router failure falls back to the local ``describe_montage_sequence``.
    Mutates ``keyframes`` in place.

    Args:
        keyframes: Frame dicts; ``image_path`` supplies the frames to describe,
            ``start``/``end``/``ocr``/``faces`` feed the model context.
        out_dir: Directory under which the ``montage/`` folder is created.
        face_identities: Known identity labels passed to the model as context.
        config_path: YAML config path (defaults to ``config.yaml``).

    Returns:
        Tuple of (the same ``keyframes`` list, with ``description`` set where a
        description was produced, and the montage path or ``None``).
    """
    cfg = load_yaml(config_path or "config.yaml")
    model_name = (cfg.get("background_descriptor", {}).get("description", {}) or {}).get("model", "salamandra-vision")
    frame_paths = [k.get("image_path") for k in keyframes if k.get("image_path")]
    montage_dir = out_dir / "montage"
    montage_path = None
    if frame_paths:
        montage_path = generar_montage(frame_paths, montage_dir)
    context = {
        "informacion": [{k: v for k, v in fr.items() if k in ("start", "end", "ocr", "faces")} for fr in keyframes],
        "face_identities": sorted(list(face_identities or set()))
    }
    # Initialize so a failed primary path with no usable fallback still yields
    # a well-defined (empty) description list instead of an unbound name.
    descs: List[Any] = []
    try:
        router = LLMRouter(cfg)
        descs = router.vision_describe(frame_paths, context=context, model=model_name)
    except Exception:
        # Best-effort local fallback. Only usable when a montage was actually
        # built: the previous version passed str(None) == "None" as the montage
        # path when there were no frames, handing the fallback a bogus file.
        if montage_path is not None:
            descs = describe_montage_sequence(
                montage_path=str(montage_path),
                n=len(frame_paths),
                informacion=keyframes,
                face_identities=face_identities or set(),
                config_path=config_path or "config.yaml",
            )
    for i, fr in enumerate(keyframes):
        if i < len(descs):
            fr["description"] = descs[i]
    return keyframes, str(montage_path) if montage_path else None