caarleexx committed on
Commit 4945548 · verified · 1 Parent(s): 491db5d

Upload vae_manager.py

Files changed (1)
  1. managers/vae_manager.py +90 -0
managers/vae_manager.py ADDED
@@ -0,0 +1,90 @@
+ # vae_manager.py - simple version (beta 1.0)
+ # Responsible for decoding latents (B, C, T, H, W) → pixels (B, C, T, H', W') in [0, 1].
+
+ import torch
+ import contextlib
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ from huggingface_hub import logging
+
+ # Only the last verbosity call takes effect, so a single explicit level is set here.
+ logging.set_verbosity_debug()
+
+
+ DEPS_DIR = Path("/data")
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ if not LTX_VIDEO_REPO_DIR.exists():
+     print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
+     # NOTE: run_setup() is not defined in this file; it is assumed to be provided
+     # (or imported) by the project's setup module before this manager is loaded.
+     run_setup()
+
+ def add_deps_to_path():
+     repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
+     if repo_path not in sys.path:
+         sys.path.insert(0, repo_path)
+         print(f"[DEBUG] Repo added to sys.path: {repo_path}")
+
+ add_deps_to_path()
+
+
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
+
+
+ class _SimpleVAEManager:
+     def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
+         """
+         pipeline: LTX object that exposes decode_latents(...) or .vae.decode(...)
+         device: "cuda" or "cpu", where decoding should run
+         autocast_dtype: autocast dtype when on CUDA (bf16/fp16/fp32)
+         """
+         self.pipeline = pipeline
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.autocast_dtype = autocast_dtype
+
+     def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
+         self.pipeline = pipeline
+         if device is not None:
+             self.device = device
+         if autocast_dtype is not None:
+             self.autocast_dtype = autocast_dtype
+
+     @torch.no_grad()
+     def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
+         # Move latents to the target device; on CUDA, also cast to the autocast dtype.
+         latent_tensor_gpu = latent_tensor.to(
+             self.device,
+             dtype=self.autocast_dtype if self.device == "cuda" else latent_tensor.dtype,
+         )
+
+         # Build the timestep vector (one entry per item in the batch B).
+         num_items_in_batch = latent_tensor_gpu.shape[0]
+         timestep_tensor = torch.tensor(
+             [decode_timestep] * num_items_in_batch,
+             device=self.device,
+             dtype=latent_tensor_gpu.dtype,
+         )
+
+         # Autocast only on CUDA; on CPU, decode in the tensor's native dtype.
+         ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
+         with ctx:
+             pixels = vae_decode(
+                 latent_tensor_gpu,
+                 self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,  # compat: accept a full pipeline or a bare VAE
+                 is_video=True,
+                 timestep=timestep_tensor,
+                 vae_per_channel_normalize=True,
+             )
+
+         # Normalize to [0, 1] if the output comes in [-1, 1].
+         if pixels.min() < 0:
+             pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
+         else:
+             pixels = pixels.clamp(0, 1)
+         return pixels
+
+
+ # Global singleton for simple use
+ vae_manager_singleton = _SimpleVAEManager()
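
For reference, a minimal usage sketch of the singleton is shown below. It assumes an LTX-Video pipeline object exposing a .vae attribute (named `pipeline` here) has already been built elsewhere, and the latent channel/size values are purely illustrative; only the (B, C, T, H, W) layout follows the module's header comment.

import torch
from managers.vae_manager import vae_manager_singleton

# `pipeline` is assumed to be an already-loaded LTX-Video pipeline with a `.vae` attribute;
# constructing it is outside the scope of this sketch.
vae_manager_singleton.attach_pipeline(pipeline, device="cuda", autocast_dtype=torch.bfloat16)

# Hypothetical latent shape (B, C, T, H, W); real channel/size values depend on the model.
latents = torch.randn(1, 128, 8, 16, 16, device="cuda")
pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
print(pixels.shape, float(pixels.min()), float(pixels.max()))  # pixel values lie in [0, 1]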