eeuuia committed (verified)
Commit 386d75b · 1 Parent(s): 92d7415

Upload 2 files

Files changed (2):
  1. app (5).py +301 -0
  2. pipeline_ltx_condition_control (1).py +1506 -0
app (5).py ADDED
@@ -0,0 +1,301 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import tempfile
+ import os
+
+ from diffusers import LTXLatentUpsamplePipeline
+ from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition
+ from diffusers.utils import export_to_video, load_video
+ from torchvision import transforms
+ import random
+ import imageio
+ from PIL import Image, ImageOps
+ import cv2
+ import shutil
+ import glob
+ from pathlib import Path
+
+ import warnings
+ import logging
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", message=".*")
+ from huggingface_hub import logging as ll
+ # Note: successive verbosity calls override each other; only the last one (debug) takes effect.
+ ll.set_verbosity_error()
+ ll.set_verbosity_warning()
+ ll.set_verbosity_info()
+ ll.set_verbosity_debug()
+ logger = logging.getLogger("AducDebug")
+ logging.basicConfig(level=logging.DEBUG)
+ logger.setLevel(logging.DEBUG)
+
+ FPS = 24
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ base_model_repo = "Lightricks/LTX-Video"
+ print(f"Carregando a arquitetura completa da pipeline de {base_model_repo}...")
+ pipeline = LTXConditionPipeline.from_pretrained(
+     base_model_repo,
+     torch_dtype=dtype,
+     cache_dir=os.getenv("HF_HOME_CACHE"),
+     token=os.getenv("HF_TOKEN"),
+ )
+
+ # 2. Define the URL for the FP8 weights file that contains only the TRANSFORMER.
+ fp8_transformer_weights_url = "https://huggingface.co/Lightricks/LTX-Video/ltxv-13b-0.9.8-distilled-fp8.safetensors"
+ print(f"Sobrescrevendo pesos do Transformer com o arquivo FP8 de: {fp8_transformer_weights_url}")
+
+ pipeline.load_lora_weights(fp8_transformer_weights_url, from_diffusers=True)
+
+ print("Carregando upsampler...")
+ pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
+     "Lightricks/ltxv-spatial-upscaler-0.9.7",
+     cache_dir=os.getenv("HF_HOME_CACHE"),
+     vae=pipeline.vae,
+     torch_dtype=dtype
+ )
+
+ print("Movendo modelos para o dispositivo...")
+ pipeline.to(device)
+ pipe_upsample.to(device)
+ pipeline.vae.enable_tiling()
+
+ current_dir = Path(__file__).parent
+
+ def cleanup_session_files(request: gr.Request):
+     """Clean up the session's temporary files when the user disconnects."""
+     try:
+         session_id = request.session_hash
+         session_dir = os.path.join("/tmp/gradio", session_id)
+         if os.path.exists(session_dir):
+             shutil.rmtree(session_dir)
+             print(f"Limpou o diretório da sessão: {session_dir}")
+     except Exception as e:
+         print(f"Erro durante a limpeza da sessão: {e}")
+
+ def read_video(video) -> torch.Tensor:
+     """Read a video file and convert it to a torch tensor."""
+     to_tensor_transform = transforms.ToTensor()
+     if isinstance(video, str):
+         video_tensor = torch.stack([to_tensor_transform(img) for img in imageio.get_reader(video)])
+     else:
+         video_tensor = torch.stack([to_tensor_transform(img) for img in video])
+     return video_tensor
+
+
+ def round_to_nearest_resolution_acceptable_by_vae(height, width, vae_temporal_compression_ratio):
+     """Round the resolution down to values acceptable to the VAE."""
+     height = height - (height % vae_temporal_compression_ratio)
+     width = width - (width % vae_temporal_compression_ratio)
+     return height, width
+
+
+ # The function signature accepts individual arguments again, for compatibility with Gradio
+ def generate_video(
+     condition_image_1,
+     condition_strength_1,
+     condition_frame_index_1,
+     condition_image_2,
+     condition_strength_2,
+     condition_frame_index_2,
+     prompt,
+     duration=3.0,
+     negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
+     height=768,
+     width=1152,
+     num_inference_steps=7,
+     guidance_scale=1.0,
+     seed=0,
+     randomize_seed=False,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     try:
+         # Logic to group the conditions *inside* the function
+         # Frame count and resolution calculation
+         num_frames = int(duration * FPS) + 1
+         temporal_compression = pipeline.vae_temporal_compression_ratio
+         num_frames = ((num_frames - 1) // temporal_compression) * temporal_compression + 1
+
+         downscale_factor = 2 / 3
+         downscaled_height = int(height * downscale_factor)
+         downscaled_width = int(width * downscale_factor)
+         downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
+             downscaled_height, downscaled_width, pipeline.vae_temporal_compression_ratio
+         )
+
+         conditions = []
+         if condition_image_1 is not None:
+             condition_image_1 = ImageOps.fit(condition_image_1, (downscaled_width, downscaled_height), Image.LANCZOS)
+             conditions.append(LTXVideoCondition(
+                 image=condition_image_1,
+                 strength=condition_strength_1,
+                 frame_index=int(condition_frame_index_1)
+             ))
+         if condition_image_2 is not None:
+             condition_image_2 = ImageOps.fit(condition_image_2, (downscaled_width, downscaled_height), Image.LANCZOS)
+             conditions.append(LTXVideoCondition(
+                 image=condition_image_2,
+                 strength=condition_strength_2,
+                 frame_index=int(condition_frame_index_2)
+             ))
+
+         pipeline_args = {}
+         if conditions:
+             pipeline_args["conditions"] = conditions
+
+         # Seed handling
+         if randomize_seed:
+             seed = random.randint(0, 2**32 - 1)
+
+         # STAGE 1: Low-resolution video generation
+         latents = pipeline(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             width=downscaled_width,
+             height=downscaled_height,
+             num_frames=num_frames,
+             timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
+             decode_timestep=0.05,
+             decode_noise_scale=0.025,
+             image_cond_noise_scale=0.0,
+             guidance_scale=guidance_scale,
+             guidance_rescale=0.7,
+             generator=torch.Generator().manual_seed(seed),
+             output_type="latent",
+             **pipeline_args
+         ).frames
+
+         # STAGE 2: Latent upscaling (currently disabled)
+         # upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+         # upscaled_latents = pipe_upsample(
+         #     latents=latents,
+         #     output_type="latent"
+         # ).frames
+
+         print(f"ETAPA 1 latents {latents.shape}")
+
+         # STAGE 3: Final denoise (still at the downscaled size while STAGE 2 is disabled)
+         final_video_frames_np = pipeline(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             width=downscaled_width,
+             height=downscaled_height,
+             num_frames=num_frames,
+             denoise_strength=0.999,
+             timesteps=[1000, 909, 725, 421, 0],
+             latents=latents,
+             decode_timestep=0.05,
+             decode_noise_scale=0.025,
+             image_cond_noise_scale=0.0,
+             guidance_scale=guidance_scale,
+             guidance_rescale=0.7,
+             generator=torch.Generator(device="cuda").manual_seed(seed),
+             output_type="np",
+             **pipeline_args
+         ).frames[0]
+
+         print(f"ETAPA 3 final_video_frames_np {final_video_frames_np.shape}")
+
+         # Export to an MP4 file
+         video_uint8_frames = [(frame * 255).astype(np.uint8) for frame in final_video_frames_np]
+         output_filename = "output.mp4"
+         with imageio.get_writer(output_filename, fps=FPS, quality=8, macro_block_size=1) as writer:
+             for frame_idx, frame_data in enumerate(video_uint8_frames):
+                 progress((frame_idx + 1) / len(video_uint8_frames), desc="Codificando frames do vídeo...")
+                 writer.append_data(frame_data)
+
+         return output_filename, seed
+
+     except Exception as e:
+         print(f"Ocorreu um erro: {e}")
+         return None, seed
+
+ # Gradio user interface
+ with gr.Blocks(theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"]), delete_cache=(60, 900)) as demo:
+     gr.Markdown(
+         """
+         # Geração de Vídeo com LTX
+         **Crie vídeos a partir de texto e imagens de condição usando o modelo LTX-Video.**
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Descreva o vídeo que você quer gerar...",
+                 lines=3,
+                 value="O Coringa em seu icônico terno roxo e cabelo verde, dançando sozinho em um quarto escuro e decadente. Seus movimentos são erráticos e imprevisíveis, alternando entre graciosos e caóticos enquanto ele se perde no momento. A câmera captura seus gestos teatrais, sua dança refletindo sua personalidade desequilibrada. Iluminação temperamental com sombras dançando pelas paredes, criando uma atmosfera de bela loucura."
+             )
+
+             with gr.Accordion("Imagem de Condição 1", open=True):
+                 condition_image_1 = gr.Image(label="Imagem de Condição 1", type="pil")
+                 with gr.Row():
+                     condition_strength_1 = gr.Slider(label="Peso (Strength)", minimum=0.0, maximum=1.0, step=0.05, value=1.0)
+                     condition_frame_index_1 = gr.Number(label="Frame", value=0, precision=0)
+
+             with gr.Accordion("Imagem de Condição 2", open=False):
+                 condition_image_2 = gr.Image(label="Imagem de Condição 2", type="pil")
+                 with gr.Row():
+                     condition_strength_2 = gr.Slider(label="Peso (Strength)", minimum=0.0, maximum=1.0, step=0.05, value=1.0)
+                     condition_frame_index_2 = gr.Number(label="Frame", value=0, precision=0)
+
+             duration = gr.Slider(label="Duração (segundos)", minimum=1.0, maximum=10.0, step=0.5, value=2)
+
+             with gr.Accordion("Configurações Avançadas", open=False):
+                 negative_prompt = gr.Textbox(label="Prompt Negativo", placeholder="O que você não quer no vídeo...", lines=2, value="pior qualidade, movimento inconsistente, embaçado, tremido, distorcido")
+                 with gr.Row():
+                     height = gr.Slider(label="Altura", minimum=256, maximum=1536, step=32, value=768)
+                     width = gr.Slider(label="Largura", minimum=256, maximum=1536, step=32, value=1152)
+
+                 num_inference_steps = gr.Slider(label="Passos de Inferência", minimum=5, maximum=10, step=1, value=7, visible=False)
+
+                 with gr.Row():
+                     guidance_scale = gr.Slider(label="Escala de Orientação (Guidance)", minimum=1.0, maximum=5.0, step=0.1, value=1.0)
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
+                     seed = gr.Number(label="Seed", value=0, precision=0)
+
+             generate_btn = gr.Button("Gerar Vídeo", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             output_video = gr.Video(label="Vídeo Gerado", height=400)
+
+     # FIX: the inputs list is now "flat", containing only Gradio components
+     generate_btn.click(
+         fn=generate_video,
+         inputs=[
+             condition_image_1,
+             condition_strength_1,
+             condition_frame_index_1,
+             condition_image_2,
+             condition_strength_2,
+             condition_frame_index_2,
+             prompt,
+             duration,
+             negative_prompt,
+             height,
+             width,
+             num_inference_steps,
+             guidance_scale,
+             seed,
+             randomize_seed,
+         ],
+         outputs=[output_video, seed],
+         show_progress=True
+     )
+
+     demo.unload(cleanup_session_files)
+
+
+ if __name__ == "__main__":
+     demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)
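Because `generate_video` is an ordinary Python function, the flow above can also be smoke-tested without the Gradio front end. The sketch below is a minimal, untested example under stated assumptions: it supposes the file above is importable as `app` (importing it loads the pipelines, which needs a CUDA GPU), that a local `first_frame.png` exists, and it substitutes a no-op callable for the `gr.Progress` default; none of these names are part of the commit itself.

```py
# Hypothetical headless smoke test (assumptions: the file above is saved/importable as `app`,
# a CUDA GPU is available, and first_frame.png is a local image).
from PIL import Image

from app import generate_video  # importing also loads the LTX pipelines defined above

first_frame = Image.open("first_frame.png")

video_path, used_seed = generate_video(
    condition_image_1=first_frame,
    condition_strength_1=1.0,
    condition_frame_index_1=0,
    condition_image_2=None,
    condition_strength_2=1.0,
    condition_frame_index_2=0,
    prompt="A slow dolly shot through a foggy forest at dawn",
    duration=2.0,
    randomize_seed=True,
    progress=lambda *args, **kwargs: None,  # stand-in for gr.Progress outside the UI
)
print(video_path, used_seed)
```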
pipeline_ltx_condition_control (1).py ADDED
@@ -0,0 +1,1506 @@
1
+ # Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import PIL.Image
20
+ import torch
21
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
+ from diffusers.image_processor import PipelineImageInput
23
+ from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKLLTXVideo
25
+ from diffusers.models.transformers import LTXVideoTransformer3DModel
26
+ from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from transformers import T5EncoderModel, T5TokenizerFast
33
+ from torchvision.transforms.functional import center_crop, resize
34
+
35
+ import warnings
36
+ import logging
37
+ from debug_utils import log_function_io
38
+ warnings.filterwarnings("ignore", category=UserWarning)
39
+ warnings.filterwarnings("ignore", category=FutureWarning)
40
+ warnings.filterwarnings("ignore", message=".*")
41
+ from huggingface_hub import logging as ll
42
+ ll.set_verbosity_error()
43
+ ll.set_verbosity_warning()
44
+ ll.set_verbosity_info()
45
+ ll.set_verbosity_debug()
46
+ logger = logging.getLogger("AducDebug")
47
+ logging.basicConfig(level=logging.DEBUG)
48
+ logger.setLevel(logging.DEBUG)
49
+
50
+ XLA_AVAILABLE= False
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> import torch
56
+ >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
57
+ >>> from diffusers.utils import export_to_video, load_video, load_image
58
+
59
+ >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
60
+ >>> pipe.to("cuda")
61
+
62
+ >>> # Load input image and video
63
+ >>> video = load_video(
64
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
65
+ ... )
66
+ >>> image = load_image(
67
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
68
+ ... )
69
+
70
+ >>> # Create conditioning objects
71
+ >>> condition1 = LTXVideoCondition(
72
+ ... image=image,
73
+ ... frame_index=0,
74
+ ... )
75
+ >>> condition2 = LTXVideoCondition(
76
+ ... video=video,
77
+ ... frame_index=80,
78
+ ... )
79
+
80
+ >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
81
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
82
+
83
+ >>> # Generate video
84
+ >>> generator = torch.Generator("cuda").manual_seed(0)
85
+ >>> # Text-only conditioning is also supported without the need to pass `conditions`
86
+ >>> video = pipe(
87
+ ... conditions=[condition1, condition2],
88
+ ... prompt=prompt,
89
+ ... negative_prompt=negative_prompt,
90
+ ... width=768,
91
+ ... height=512,
92
+ ... num_frames=161,
93
+ ... num_inference_steps=40,
94
+ ... generator=generator,
95
+ ... ).frames[0]
96
+
97
+ >>> export_to_video(video, "output.mp4", fps=24)
98
+ ```
99
+ """
100
+
101
+
102
+ @dataclass
103
+ class LTXVideoCondition:
104
+ """
105
+ Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames.
106
+
107
+ Attributes:
108
+ image (`PIL.Image.Image`):
109
+ The image to condition the video on.
110
+ video (`List[PIL.Image.Image]`):
111
+ The video to condition the video on.
112
+ frame_index (`int`):
113
+ The frame index at which the image or video will conditionally affect the video generation.
114
+ strength (`float`, defaults to `1.0`):
115
+ The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
116
+ """
117
+
118
+ image: Optional[PIL.Image.Image] = None
119
+ video: Optional[List[PIL.Image.Image]] = None
120
+ frame_index: int = 0
121
+ strength: float = 1.0
122
+
123
+
124
+ # from LTX-Video/ltx_video/schedulers/rf.py
125
+ def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
126
+ if linear_steps is None:
127
+ linear_steps = num_steps // 2
128
+ if num_steps < 2:
129
+ return torch.tensor([1.0])
130
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
131
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
132
+ quadratic_steps = num_steps - linear_steps
133
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
134
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
135
+ const = quadratic_coef * (linear_steps**2)
136
+ quadratic_sigma_schedule = [
137
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
138
+ ]
139
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
140
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
141
+ return torch.tensor(sigma_schedule[:-1])
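+ # Worked example (illustration only, not part of the original file): with the defaults,
+ # linear_quadratic_schedule(8) returns the sigmas
+ # [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875],
+ # i.e. a linear ramp near 1.0 followed by a quadratic tail toward 0.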
142
+
143
+
144
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
145
+ def calculate_shift(
146
+ image_seq_len,
147
+ base_seq_len: int = 256,
148
+ max_seq_len: int = 4096,
149
+ base_shift: float = 0.5,
150
+ max_shift: float = 1.15,
151
+ ):
152
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
153
+ b = base_shift - m * base_seq_len
154
+ mu = image_seq_len * m + b
155
+ return mu
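+ # Worked example (illustration only, not part of the original file): with the defaults,
+ # m = (1.15 - 0.5) / (4096 - 256) ≈ 1.69e-4 and b = 0.5 - m * 256 ≈ 0.457, so an
+ # image_seq_len of 2048 gives mu ≈ 1.69e-4 * 2048 + 0.457 ≈ 0.80.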
156
+
157
+
158
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
159
+ def retrieve_timesteps(
160
+ scheduler,
161
+ num_inference_steps: Optional[int] = None,
162
+ device: Optional[Union[str, torch.device]] = None,
163
+ timesteps: Optional[List[int]] = None,
164
+ sigmas: Optional[List[float]] = None,
165
+ **kwargs,
166
+ ):
167
+ r"""
168
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
169
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
170
+
171
+ Args:
172
+ scheduler (`SchedulerMixin`):
173
+ The scheduler to get timesteps from.
174
+ num_inference_steps (`int`):
175
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
176
+ must be `None`.
177
+ device (`str` or `torch.device`, *optional*):
178
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
179
+ timesteps (`List[int]`, *optional*):
180
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
181
+ `num_inference_steps` and `sigmas` must be `None`.
182
+ sigmas (`List[float]`, *optional*):
183
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
184
+ `num_inference_steps` and `timesteps` must be `None`.
185
+
186
+ Returns:
187
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
188
+ second element is the number of inference steps.
189
+ """
190
+ if timesteps is not None and sigmas is not None:
191
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
192
+ if timesteps is not None:
193
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
194
+ if not accepts_timesteps:
195
+ raise ValueError(
196
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
197
+ f" timestep schedules. Please check whether you are using the correct scheduler."
198
+ )
199
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
200
+ timesteps = scheduler.timesteps
201
+ num_inference_steps = len(timesteps)
202
+ elif sigmas is not None:
203
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
204
+ if not accept_sigmas:
205
+ raise ValueError(
206
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
207
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
208
+ )
209
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
210
+ timesteps = scheduler.timesteps
211
+ num_inference_steps = len(timesteps)
212
+ else:
213
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
214
+ timesteps = scheduler.timesteps
215
+ return timesteps, num_inference_steps
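+ # Usage sketch (illustration only, not part of the original file): given a scheduler instance,
+ # call retrieve_timesteps(scheduler, num_inference_steps=8), retrieve_timesteps(scheduler, timesteps=[...])
+ # or retrieve_timesteps(scheduler, sigmas=[...]); passing both `timesteps` and `sigmas` raises a ValueError.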
216
+
217
+
218
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
219
+ def retrieve_latents(
220
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
221
+ ):
222
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
223
+ return encoder_output.latent_dist.sample(generator)
224
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
225
+ return encoder_output.latent_dist.mode()
226
+ elif hasattr(encoder_output, "latents"):
227
+ return encoder_output.latents
228
+ else:
229
+ raise AttributeError("Could not access latents of provided encoder_output")
230
+
231
+
232
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
233
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
234
+ r"""
235
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
236
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
237
+ Flawed](https://huggingface.co/papers/2305.08891).
238
+
239
+ Args:
240
+ noise_cfg (`torch.Tensor`):
241
+ The predicted noise tensor for the guided diffusion process.
242
+ noise_pred_text (`torch.Tensor`):
243
+ The predicted noise tensor for the text-guided diffusion process.
244
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
245
+ A rescale factor applied to the noise predictions.
246
+
247
+ Returns:
248
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
249
+ """
250
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
251
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
252
+ # rescale the results from guidance (fixes overexposure)
253
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
254
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
255
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
256
+ return noise_cfg
257
+
258
+
259
+ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
260
+ r"""
261
+ Pipeline for text/image/video-to-video generation.
262
+
263
+ Reference: https://github.com/Lightricks/LTX-Video
264
+
265
+ Args:
266
+ transformer ([`LTXVideoTransformer3DModel`]):
267
+ Conditional Transformer architecture to denoise the encoded video latents.
268
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
269
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
270
+ vae ([`AutoencoderKLLTXVideo`]):
271
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
272
+ text_encoder ([`T5EncoderModel`]):
273
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
274
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
281
+ """
282
+
283
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
284
+ _optional_components = []
285
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
286
+
287
+ def __init__(
288
+ self,
289
+ scheduler: FlowMatchEulerDiscreteScheduler,
290
+ vae: AutoencoderKLLTXVideo,
291
+ text_encoder: T5EncoderModel,
292
+ tokenizer: T5TokenizerFast,
293
+ transformer: LTXVideoTransformer3DModel,
294
+ ):
295
+ super().__init__()
296
+
297
+ self.register_modules(
298
+ vae=vae,
299
+ text_encoder=text_encoder,
300
+ tokenizer=tokenizer,
301
+ transformer=transformer,
302
+ scheduler=scheduler,
303
+ )
304
+
305
+ self.vae_spatial_compression_ratio = (
306
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
307
+ )
308
+ self.vae_temporal_compression_ratio = (
309
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
310
+ )
311
+ self.transformer_spatial_patch_size = (
312
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
313
+ )
314
+ self.transformer_temporal_patch_size = (
315
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
316
+ )
317
+
318
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
319
+ self.tokenizer_max_length = (
320
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
321
+ )
322
+
323
+ self.default_height = 512
324
+ self.default_width = 704
325
+ self.default_frames = 121
326
+
327
+ def _get_t5_prompt_embeds(
328
+ self,
329
+ prompt: Union[str, List[str]] = None,
330
+ num_videos_per_prompt: int = 1,
331
+ max_sequence_length: int = 256,
332
+ device: Optional[torch.device] = None,
333
+ dtype: Optional[torch.dtype] = None,
334
+ ):
335
+ device = device or self._execution_device
336
+ dtype = dtype or self.text_encoder.dtype
337
+
338
+ prompt = [prompt] if isinstance(prompt, str) else prompt
339
+ batch_size = len(prompt)
340
+
341
+ text_inputs = self.tokenizer(
342
+ prompt,
343
+ padding="max_length",
344
+ max_length=max_sequence_length,
345
+ truncation=True,
346
+ add_special_tokens=True,
347
+ return_tensors="pt",
348
+ )
349
+ text_input_ids = text_inputs.input_ids
350
+ prompt_attention_mask = text_inputs.attention_mask
351
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
352
+
353
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
354
+
355
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
356
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
357
+ logger.warning(
358
+ "The following part of your input was truncated because `max_sequence_length` is set to "
359
+ f" {max_sequence_length} tokens: {removed_text}"
360
+ )
361
+
362
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
363
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
364
+
365
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
366
+ _, seq_len, _ = prompt_embeds.shape
367
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
368
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
369
+
370
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
371
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
372
+
373
+ return prompt_embeds, prompt_attention_mask
374
+
375
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt
376
+ def encode_prompt(
377
+ self,
378
+ prompt: Union[str, List[str]],
379
+ negative_prompt: Optional[Union[str, List[str]]] = None,
380
+ do_classifier_free_guidance: bool = True,
381
+ num_videos_per_prompt: int = 1,
382
+ prompt_embeds: Optional[torch.Tensor] = None,
383
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
384
+ prompt_attention_mask: Optional[torch.Tensor] = None,
385
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
386
+ max_sequence_length: int = 256,
387
+ device: Optional[torch.device] = None,
388
+ dtype: Optional[torch.dtype] = None,
389
+ ):
390
+ r"""
391
+ Encodes the prompt into text encoder hidden states.
392
+
393
+ Args:
394
+ prompt (`str` or `List[str]`, *optional*):
395
+ prompt to be encoded
396
+ negative_prompt (`str` or `List[str]`, *optional*):
397
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
398
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
399
+ less than `1`).
400
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
401
+ Whether to use classifier free guidance or not.
402
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
403
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
404
+ prompt_embeds (`torch.Tensor`, *optional*):
405
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
406
+ provided, text embeddings will be generated from `prompt` input argument.
407
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
408
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
409
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
410
+ argument.
411
+ device: (`torch.device`, *optional*):
412
+ torch device
413
+ dtype: (`torch.dtype`, *optional*):
414
+ torch dtype
415
+ """
416
+ device = device or self._execution_device
417
+
418
+ prompt = [prompt] if isinstance(prompt, str) else prompt
419
+ if prompt is not None:
420
+ batch_size = len(prompt)
421
+ else:
422
+ batch_size = prompt_embeds.shape[0]
423
+
424
+ if prompt_embeds is None:
425
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
426
+ prompt=prompt,
427
+ num_videos_per_prompt=num_videos_per_prompt,
428
+ max_sequence_length=max_sequence_length,
429
+ device=device,
430
+ dtype=dtype,
431
+ )
432
+
433
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
434
+ negative_prompt = negative_prompt or ""
435
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
436
+
437
+ if prompt is not None and type(prompt) is not type(negative_prompt):
438
+ raise TypeError(
439
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
440
+ f" {type(prompt)}."
441
+ )
442
+ elif batch_size != len(negative_prompt):
443
+ raise ValueError(
444
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
445
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
446
+ " the batch size of `prompt`."
447
+ )
448
+
449
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
450
+ prompt=negative_prompt,
451
+ num_videos_per_prompt=num_videos_per_prompt,
452
+ max_sequence_length=max_sequence_length,
453
+ device=device,
454
+ dtype=dtype,
455
+ )
456
+
457
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
458
+
459
+ def check_inputs(
460
+ self,
461
+ prompt,
462
+ conditions,
463
+ image,
464
+ video,
465
+ frame_index,
466
+ strength,
467
+ denoise_strength,
468
+ height,
469
+ width,
470
+ callback_on_step_end_tensor_inputs=None,
471
+ prompt_embeds=None,
472
+ negative_prompt_embeds=None,
473
+ prompt_attention_mask=None,
474
+ negative_prompt_attention_mask=None,
475
+ reference_video=None,
476
+ ):
477
+ if height % 32 != 0 or width % 32 != 0:
478
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
479
+
480
+ if callback_on_step_end_tensor_inputs is not None and not all(
481
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
482
+ ):
483
+ raise ValueError(
484
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
485
+ )
486
+
487
+ if prompt is not None and prompt_embeds is not None:
488
+ raise ValueError(
489
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
490
+ " only forward one of the two."
491
+ )
492
+ elif prompt is None and prompt_embeds is None:
493
+ raise ValueError(
494
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
495
+ )
496
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
497
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
498
+
499
+ if prompt_embeds is not None and prompt_attention_mask is None:
500
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
501
+
502
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
503
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
504
+
505
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
506
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
507
+ raise ValueError(
508
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
509
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
510
+ f" {negative_prompt_embeds.shape}."
511
+ )
512
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
513
+ raise ValueError(
514
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
515
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
516
+ f" {negative_prompt_attention_mask.shape}."
517
+ )
518
+
519
+ if conditions is not None and (image is not None or video is not None):
520
+ raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
521
+
522
+ if conditions is None:
523
+ if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
524
+ raise ValueError(
525
+ "If `conditions` is not provided, `image` and `frame_index` must be of the same length."
526
+ )
527
+ elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength):
528
+ raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.")
529
+ elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index):
530
+ raise ValueError(
531
+ "If `conditions` is not provided, `video` and `frame_index` must be of the same length."
532
+ )
533
+ elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
534
+ raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
535
+
536
+ if denoise_strength < 0 or denoise_strength > 1:
537
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
538
+
539
+ if reference_video is not None:
540
+ if not isinstance(reference_video, torch.Tensor):
541
+ raise ValueError(
542
+ "`reference_video` must be a torch.Tensor with shape [F, C, H, W] as returned by read_video()."
543
+ )
544
+ if reference_video.ndim != 4:
545
+ raise ValueError(
546
+ f"`reference_video` must be a 4D tensor with shape [F, C, H, W], but got shape {reference_video.shape}."
547
+ )
548
+
549
+ @staticmethod
550
+ def _prepare_video_ids(
551
+ batch_size: int,
552
+ num_frames: int,
553
+ height: int,
554
+ width: int,
555
+ patch_size: int = 1,
556
+ patch_size_t: int = 1,
557
+ device: torch.device = None,
558
+ ) -> torch.Tensor:
559
+ latent_sample_coords = torch.meshgrid(
560
+ torch.arange(0, num_frames, patch_size_t, device=device),
561
+ torch.arange(0, height, patch_size, device=device),
562
+ torch.arange(0, width, patch_size, device=device),
563
+ indexing="ij",
564
+ )
565
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
566
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
567
+ latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
568
+
569
+ return latent_coords
570
+
571
+ @staticmethod
572
+ def _scale_video_ids(
573
+ video_ids: torch.Tensor,
574
+ scale_factor: int = 32,
575
+ scale_factor_t: int = 8,
576
+ frame_index: int = 0,
577
+ device: torch.device = None,
578
+ ) -> torch.Tensor:
579
+ scaled_latent_coords = (
580
+ video_ids
581
+ * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
582
+ )
583
+ scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
584
+ scaled_latent_coords[:, 0] += frame_index
585
+
586
+ return scaled_latent_coords
587
+
588
+
589
+
590
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
591
+ @staticmethod
592
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
593
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
594
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
595
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
596
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
597
+ batch_size, num_channels, num_frames, height, width = latents.shape
598
+ post_patch_num_frames = num_frames // patch_size_t
599
+ post_patch_height = height // patch_size
600
+ post_patch_width = width // patch_size
601
+ latents = latents.reshape(
602
+ batch_size,
603
+ -1,
604
+ post_patch_num_frames,
605
+ patch_size_t,
606
+ post_patch_height,
607
+ patch_size,
608
+ post_patch_width,
609
+ patch_size,
610
+ )
611
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
612
+ return latents
613
+
614
+ @staticmethod
615
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
616
+ def _unpack_latents(
617
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
618
+ ) -> torch.Tensor:
619
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
620
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
621
+ # what happens in the `_pack_latents` method.
622
+ batch_size = latents.size(0)
623
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
624
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
625
+ return latents
626
+
627
+ @staticmethod
628
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
629
+ def _normalize_latents(
630
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
631
+ ) -> torch.Tensor:
632
+ # Normalize latents across the channel dimension [B, C, F, H, W]
633
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
634
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
635
+ latents = (latents - latents_mean) * scaling_factor / latents_std
636
+ return latents
637
+
638
+ @staticmethod
639
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
640
+ def _denormalize_latents(
641
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
642
+ ) -> torch.Tensor:
643
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
644
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
645
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
646
+ latents = latents * latents_std / scaling_factor + latents_mean
647
+ return latents
648
+
649
+ def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int):
650
+ """
651
+ Trim a conditioning sequence to the allowed number of frames.
652
+
653
+ Args:
654
+ start_frame (int): The target frame number of the first frame in the sequence.
655
+ sequence_num_frames (int): The number of frames in the sequence.
656
+ target_num_frames (int): The target number of frames in the generated video.
657
+ Returns:
658
+ int: updated sequence length
659
+ """
660
+ scale_factor = self.vae_temporal_compression_ratio
661
+ num_frames = min(sequence_num_frames, target_num_frames - start_frame)
662
+ # Trim down to a multiple of temporal_scale_factor frames plus 1
663
+ num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
664
+ return num_frames
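+ # Example (illustration only, not part of the original file): with a temporal compression ratio
+ # of 8, start_frame=0, sequence_num_frames=40 and target_num_frames=97, the sequence is first
+ # capped at min(40, 97) = 40 frames and then trimmed to (40 - 1) // 8 * 8 + 1 = 33 frames.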
665
+
666
+ @staticmethod
667
+ def add_noise_to_image_conditioning_latents(
668
+ t: float,
669
+ init_latents: torch.Tensor,
670
+ latents: torch.Tensor,
671
+ noise_scale: float,
672
+ conditioning_mask: torch.Tensor,
673
+ generator,
674
+ eps=1e-6,
675
+ ):
676
+ """
677
+ Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially
678
+ when conditioned on a single frame.
679
+ """
680
+ noise = randn_tensor(
681
+ latents.shape,
682
+ generator=generator,
683
+ device=latents.device,
684
+ dtype=latents.dtype,
685
+ )
686
+ # Add noise only to hard-conditioning latents (conditioning_mask = 1.0)
687
+ need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
688
+ noised_latents = init_latents + noise_scale * noise * (t**2)
689
+ latents = torch.where(need_to_noise, noised_latents, latents)
690
+ return latents
691
+
692
+ def prepare_latents(
693
+ self,
694
+ conditions: Optional[List[torch.Tensor]] = None,
695
+ condition_strength: Optional[List[float]] = None,
696
+ condition_frame_index: Optional[List[int]] = None,
697
+ batch_size: int = 1,
698
+ num_channels_latents: int = 128,
699
+ height: int = 512,
700
+ width: int = 704,
701
+ num_frames: int = 161,
702
+ num_prefix_latent_frames: int = 2,
703
+ sigma: Optional[torch.Tensor] = None,
704
+ latents: Optional[torch.Tensor] = None,
705
+ generator: Optional[torch.Generator] = None,
706
+ device: Optional[torch.device] = None,
707
+ dtype: Optional[torch.dtype] = None,
708
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
709
+ num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
710
+ latent_height = height // self.vae_spatial_compression_ratio
711
+ latent_width = width // self.vae_spatial_compression_ratio
712
+
713
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
714
+
715
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
716
+ if latents is not None and sigma is not None:
717
+ if latents.shape != shape:
718
+ raise ValueError(
719
+ f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
720
+ )
721
+ latents = latents.to(device=device, dtype=dtype)
722
+ sigma = sigma.to(device=device, dtype=dtype)
723
+ latents = sigma * noise + (1 - sigma) * latents
724
+ else:
725
+ latents = noise
726
+
727
+ if len(conditions) > 0:
728
+ condition_latent_frames_mask = torch.zeros(
729
+ (batch_size, num_latent_frames), device=device, dtype=torch.float32
730
+ )
731
+
732
+ extra_conditioning_latents = []
733
+ extra_conditioning_video_ids = []
734
+ extra_conditioning_mask = []
735
+ extra_conditioning_num_latents = 0
736
+ for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index, strict=False):
737
+ condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
738
+ condition_latents = self._normalize_latents(
739
+ condition_latents, self.vae.latents_mean, self.vae.latents_std
740
+ ).to(device, dtype=dtype)
741
+
742
+ num_data_frames = data.size(2)
743
+ num_cond_frames = condition_latents.size(2)
744
+
745
+ if frame_index == 0:
746
+ latents[:, :, :num_cond_frames] = torch.lerp(
747
+ latents[:, :, :num_cond_frames], condition_latents, strength
748
+ )
749
+ condition_latent_frames_mask[:, :num_cond_frames] = strength
750
+
751
+ else:
752
+ if num_data_frames > 1:
753
+ if num_cond_frames < num_prefix_latent_frames:
754
+ raise ValueError(
755
+ f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
756
+ )
757
+
758
+ if num_cond_frames > num_prefix_latent_frames:
759
+ start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
760
+ end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
761
+ latents[:, :, start_frame:end_frame] = torch.lerp(
762
+ latents[:, :, start_frame:end_frame],
763
+ condition_latents[:, :, num_prefix_latent_frames:],
764
+ strength,
765
+ )
766
+ condition_latent_frames_mask[:, start_frame:end_frame] = strength
767
+ condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
768
+
769
+ noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
770
+ condition_latents = torch.lerp(noise, condition_latents, strength)
771
+
772
+ condition_video_ids = self._prepare_video_ids(
773
+ batch_size,
774
+ condition_latents.size(2),
775
+ latent_height,
776
+ latent_width,
777
+ patch_size=self.transformer_spatial_patch_size,
778
+ patch_size_t=self.transformer_temporal_patch_size,
779
+ device=device,
780
+ )
781
+ condition_video_ids = self._scale_video_ids(
782
+ condition_video_ids,
783
+ scale_factor=self.vae_spatial_compression_ratio,
784
+ scale_factor_t=self.vae_temporal_compression_ratio,
785
+ frame_index=frame_index,
786
+ device=device,
787
+ )
788
+ condition_latents = self._pack_latents(
789
+ condition_latents,
790
+ self.transformer_spatial_patch_size,
791
+ self.transformer_temporal_patch_size,
792
+ )
793
+ condition_conditioning_mask = torch.full(
794
+ condition_latents.shape[:2], strength, device=device, dtype=dtype
795
+ )
796
+
797
+ extra_conditioning_latents.append(condition_latents)
798
+ extra_conditioning_video_ids.append(condition_video_ids)
799
+ extra_conditioning_mask.append(condition_conditioning_mask)
800
+ extra_conditioning_num_latents += condition_latents.size(1)
801
+
802
+ video_ids = self._prepare_video_ids(
803
+ batch_size,
804
+ num_latent_frames,
805
+ latent_height,
806
+ latent_width,
807
+ patch_size_t=self.transformer_temporal_patch_size,
808
+ patch_size=self.transformer_spatial_patch_size,
809
+ device=device,
810
+ )
811
+ if len(conditions) > 0:
812
+ conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
813
+ else:
814
+ conditioning_mask, extra_conditioning_num_latents = None, 0
815
+ video_ids = self._scale_video_ids(
816
+ video_ids,
817
+ scale_factor=self.vae_spatial_compression_ratio,
818
+ scale_factor_t=self.vae_temporal_compression_ratio,
819
+ frame_index=0,
820
+ device=device,
821
+ )
822
+ latents = self._pack_latents(
823
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
824
+ )
825
+
826
+ if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
827
+ latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
828
+ video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
829
+ conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
830
+
831
+ return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
832
+
833
+ def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
834
+ num_steps = min(int(num_inference_steps * strength), num_inference_steps)
835
+ start_index = max(num_inference_steps - num_steps, 0)
836
+ sigmas = sigmas[start_index:]
837
+ timesteps = timesteps[start_index:]
838
+ return sigmas, timesteps, num_inference_steps - start_index
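+ # Example (illustration only, not part of the original file): with num_inference_steps=8 and
+ # strength=0.5, num_steps = 4 and start_index = 4, so only the last 4 sigmas/timesteps are kept.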
839
+
840
+ @property
841
+ def guidance_scale(self):
842
+ return self._guidance_scale
843
+
844
+ @property
845
+ def guidance_rescale(self):
846
+ return self._guidance_rescale
847
+
848
+ @property
849
+ def do_classifier_free_guidance(self):
850
+ return self._guidance_scale > 1.0
851
+
852
+ @property
853
+ def num_timesteps(self):
854
+ return self._num_timesteps
855
+
856
+ @property
857
+ def current_timestep(self):
858
+ return self._current_timestep
859
+
860
+ @property
861
+ def attention_kwargs(self):
862
+ return self._attention_kwargs
863
+
864
+ @property
865
+ def interrupt(self):
866
+ return self._interrupt
867
+
868
+ @torch.no_grad()
869
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
870
+ def __call__(
871
+ self,
872
+ conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None,
873
+ image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
874
+ video: List[PipelineImageInput] = None,
875
+ frame_index: Union[int, List[int]] = 0,
876
+ strength: Union[float, List[float]] = 1.0,
877
+ denoise_strength: float = 1.0,
878
+ prompt: Union[str, List[str]] = None,
879
+ negative_prompt: Optional[Union[str, List[str]]] = None,
880
+ height: int = 512,
881
+ width: int = 704,
882
+ num_frames: int = 161,
883
+ frame_rate: int = 25,
884
+ num_inference_steps: int = 50,
885
+ timesteps: List[int] = None,
886
+ guidance_scale: float = 3,
887
+ guidance_rescale: float = 0.0,
888
+ image_cond_noise_scale: float = 0.15,
889
+ num_videos_per_prompt: Optional[int] = 1,
890
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
891
+ latents: Optional[torch.Tensor] = None,
892
+ reference_video: Optional[torch.Tensor] = None,
893
+ output_reference_comparison: bool = False,
894
+ prompt_embeds: Optional[torch.Tensor] = None,
895
+ prompt_attention_mask: Optional[torch.Tensor] = None,
896
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
897
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
898
+ decode_timestep: Union[float, List[float]] = 0.0,
899
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
900
+ output_type: Optional[str] = "pil",
901
+ return_dict: bool = True,
902
+ attention_kwargs: Optional[Dict[str, Any]] = None,
903
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
904
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
905
+ max_sequence_length: int = 256,
906
+ ):
907
+ r"""
908
+ Function invoked when calling the pipeline for generation.
909
+
910
+ Args:
911
+ conditions (`List[LTXVideoCondition], *optional*`):
912
+ The list of frame-conditioning items for the video generation. If not provided, conditions will be
913
+ created using `image`, `video`, `frame_index` and `strength`.
914
+ image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
915
+ The image or images to condition the video generation. If not provided, one has to pass `video` or
916
+ `conditions`.
917
+ video (`List[PipelineImageInput]`, *optional*):
918
+ The video to condition the video generation. If not provided, one has to pass `image` or `conditions`.
919
+ frame_index (`int` or `List[int]`, *optional*):
920
+ The frame index or frame indices at which the image or video will conditionally affect the video
921
+ generation. If not provided, one has to pass `conditions`.
922
+ strength (`float` or `List[float]`, *optional*):
923
+ The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
924
+ denoise_strength (`float`, defaults to `1.0`):
925
+ The strength of the noise added to the latents for editing. A higher strength adds more noise to the
926
+ latents and therefore produces larger differences between the original and generated videos. This
927
+ is useful for video-to-video editing.
928
+ prompt (`str` or `List[str]`, *optional*):
929
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
930
+ instead.
931
+ height (`int`, defaults to `512`):
932
+ The height in pixels of the generated video. This is set to 512 by default for the best results.
933
+ width (`int`, defaults to `704`):
934
+ The width in pixels of the generated video. This is set to 704 by default for the best results.
935
+ num_frames (`int`, defaults to `161`):
936
+ The number of video frames to generate.
937
+ num_inference_steps (`int`, *optional*, defaults to 50):
938
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
939
+ expense of slower inference.
940
+ timesteps (`List[int]`, *optional*):
941
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
942
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
943
+ passed will be used. Must be in descending order.
944
+ guidance_scale (`float`, defaults to `3`):
945
+ Guidance scale as defined in [Classifier-Free Diffusion
946
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
947
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
948
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate videos that are closely
949
+ linked to the text `prompt`, usually at the expense of lower video quality.
950
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
951
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
952
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor is defined as `φ` in equation 16 of
953
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
954
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
955
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
956
+ The number of videos to generate per prompt.
957
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
958
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
959
+ to make generation deterministic.
960
+ latents (`torch.Tensor`, *optional*):
961
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
962
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
963
+ tensor will be generated by sampling using the supplied random `generator`.
964
+ reference_video (`torch.Tensor`, *optional*):
965
+ An optional reference video to guide the generation process. Should be a tensor with shape
966
+ [F, C, H, W] in range [0, 1] as returned by `read_video()` from video_utils. The reference video
967
+ will be encoded and concatenated to the latent sequence, providing global guidance while remaining
968
+ unchanged during denoising. The reference video can be of any size and will be automatically
969
+ resized and cropped to match the target dimensions.
970
+ output_reference_comparison (`bool`, defaults to `False`):
971
+ Whether to output a side-by-side comparison showing both the reference video (if provided) and the
972
+ generated video. If `False`, only the generated video is returned. Only applies when `reference_video`
973
+ is provided.
974
+ prompt_embeds (`torch.Tensor`, *optional*):
975
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
976
+ provided, text embeddings will be generated from `prompt` input argument.
977
+ prompt_attention_mask (`torch.Tensor`, *optional*):
978
+ Pre-generated attention mask for text embeddings.
979
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
980
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
981
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
982
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
983
+ Pre-generated attention mask for negative text embeddings.
984
+ decode_timestep (`float`, defaults to `0.0`):
985
+ The timestep at which generated video is decoded.
986
+ decode_noise_scale (`float`, defaults to `None`):
987
+ The interpolation factor between random noise and denoised latents at the decode timestep.
988
+ output_type (`str`, *optional*, defaults to `"pil"`):
989
+ The output format of the generated video. Choose between
990
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
991
+ return_dict (`bool`, *optional*, defaults to `True`):
992
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
993
+ attention_kwargs (`dict`, *optional*):
994
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
995
+ `self.processor` in
996
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
997
+ callback_on_step_end (`Callable`, *optional*):
998
+ A function that is called at the end of each denoising step during inference. The function is called
999
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1000
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1001
+ `callback_on_step_end_tensor_inputs`.
1002
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1003
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1004
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1005
+ `._callback_tensor_inputs` attribute of your pipeline class.
1006
+ max_sequence_length (`int`, *optional*, defaults to `256`):
1007
+ Maximum sequence length to use with the `prompt`.
1008
+
1009
+ Examples:
1010
+
1011
+ Returns:
1012
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
1013
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
1014
+ returned where the first element is a list with the generated images.
1015
+ """
1016
+
1017
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1018
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1019
+ # if latents is not None:
1020
+ # raise ValueError("Passing latents is not yet supported.")
1021
+
1022
+ # 1. Check inputs. Raise error if not correct
1023
+ self.check_inputs(
1024
+ prompt=prompt,
1025
+ conditions=conditions,
1026
+ image=image,
1027
+ video=video,
1028
+ frame_index=frame_index,
1029
+ strength=strength,
1030
+ denoise_strength=denoise_strength,
1031
+ height=height,
1032
+ width=width,
1033
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1034
+ prompt_embeds=prompt_embeds,
1035
+ negative_prompt_embeds=negative_prompt_embeds,
1036
+ prompt_attention_mask=prompt_attention_mask,
1037
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
1038
+ reference_video=reference_video,
1039
+ )
1040
+
1041
+ self._guidance_scale = guidance_scale
1042
+ self._guidance_rescale = guidance_rescale
1043
+ self._attention_kwargs = attention_kwargs
1044
+ self._interrupt = False
1045
+ self._current_timestep = None
1046
+
1047
+ # 2. Define call parameters
1048
+ if prompt is not None and isinstance(prompt, str):
1049
+ batch_size = 1
1050
+ elif prompt is not None and isinstance(prompt, list):
1051
+ batch_size = len(prompt)
1052
+ else:
1053
+ batch_size = prompt_embeds.shape[0]
1054
+
1055
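+ # `conditions` objects take precedence: their image/video, frame_index and strength fields are
+ # unpacked into the per-condition lists used below. Otherwise `image`/`video` plus `frame_index`
+ # and `strength` are normalized into lists of equal length.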
+ if conditions is not None:
1056
+ if not isinstance(conditions, list):
1057
+ conditions = [conditions]
1058
+
1059
+ strength = [condition.strength for condition in conditions]
1060
+ frame_index = [condition.frame_index for condition in conditions]
1061
+ image = [condition.image for condition in conditions]
1062
+ video = [condition.video for condition in conditions]
1063
+ elif image is not None or video is not None:
1064
+ if not isinstance(image, list):
1065
+ image = [image]
1066
+ num_conditions = 1
1067
+ elif isinstance(image, list):
1068
+ num_conditions = len(image)
1069
+ if not isinstance(video, list):
1070
+ video = [video]
1071
+ num_conditions = 1
1072
+ elif isinstance(video, list):
1073
+ num_conditions = len(video)
1074
+
1075
+ if not isinstance(frame_index, list):
1076
+ frame_index = [frame_index] * num_conditions
1077
+ if not isinstance(strength, list):
1078
+ strength = [strength] * num_conditions
1079
+
1080
+ device = self._execution_device
1081
+ vae_dtype = self.vae.dtype
1082
+
1083
+ # 3. Prepare text embeddings & conditioning image/video
1084
+ (
1085
+ prompt_embeds,
1086
+ prompt_attention_mask,
1087
+ negative_prompt_embeds,
1088
+ negative_prompt_attention_mask,
1089
+ ) = self.encode_prompt(
1090
+ prompt=prompt,
1091
+ negative_prompt=negative_prompt,
1092
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1093
+ num_videos_per_prompt=num_videos_per_prompt,
1094
+ prompt_embeds=prompt_embeds,
1095
+ negative_prompt_embeds=negative_prompt_embeds,
1096
+ prompt_attention_mask=prompt_attention_mask,
1097
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
1098
+ max_sequence_length=max_sequence_length,
1099
+ device=device,
1100
+ )
1101
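+ # For classifier-free guidance, the negative and positive prompt embeddings are stacked into one
+ # batch so that the unconditional and conditional passes run in a single transformer call.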
+ if self.do_classifier_free_guidance:
1102
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1103
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
1104
+
1105
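+ # Each conditioning image/video is preprocessed to the VAE input format; conditioning videos are
+ # additionally trimmed so they do not extend past the requested `num_frames`.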
+ conditioning_tensors = []
1106
+ is_conditioning_image_or_video = image is not None or video is not None
1107
+ if is_conditioning_image_or_video:
1108
+ for condition_image, condition_video, condition_frame_index, condition_strength in zip(
1109
+ image, video, frame_index, strength, strict=False
1110
+ ):
1111
+ if condition_image is not None:
1112
+ condition_tensor = (
1113
+ self.video_processor.preprocess(condition_image, height, width)
1114
+ .unsqueeze(2)
1115
+ .to(device, dtype=vae_dtype)
1116
+ )
1117
+ elif condition_video is not None:
1118
+ condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
1119
+ num_frames_input = condition_tensor.size(2)
1120
+ num_frames_output = self.trim_conditioning_sequence(
1121
+ condition_frame_index, num_frames_input, num_frames
1122
+ )
1123
+ condition_tensor = condition_tensor[:, :, :num_frames_output]
1124
+ condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
1125
+ else:
1126
+ raise ValueError("Either `image` or `video` must be provided for conditioning.")
1127
+
1128
+ if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
1129
+ raise ValueError(
1130
+ f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
1131
+ f"but got {condition_tensor.size(2)} frames."
1132
+ )
1133
+ conditioning_tensors.append(condition_tensor)
1134
+
1135
+ # 4. Prepare timesteps
1136
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
1137
+ latent_height = height // self.vae_spatial_compression_ratio
1138
+ latent_width = width // self.vae_spatial_compression_ratio
1139
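+ # Without custom timesteps, sigmas follow the linear-quadratic schedule and are mapped to
+ # timesteps in [0, 1000] via t = sigma * 1000.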
+ if timesteps is None:
1140
+ sigmas = linear_quadratic_schedule(num_inference_steps)
1141
+ timesteps = sigmas * 1000
1142
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1143
+ sigmas = self.scheduler.sigmas
1144
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1145
+
1146
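+ # For video-to-video editing (denoise_strength < 1), only the tail of the schedule is kept and the
+ # first retained sigma is used to noise the input latents to the matching level.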
+ latent_sigma = None
1147
+ if denoise_strength < 1:
1148
+ sigmas, timesteps, num_inference_steps = self.get_timesteps(
1149
+ sigmas, timesteps, num_inference_steps, denoise_strength
1150
+ )
1151
+ latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
1152
+
1153
+ self._num_timesteps = len(timesteps)
1154
+
1155
+ # 5. Prepare latent variables
1156
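+ # prepare_latents packs the (noised) target latents together with any frame-conditioning latents
+ # and returns a per-token conditioning mask, per-token video coordinates, and the number of extra
+ # conditioning tokens appended to the sequence.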
+ num_channels_latents = self.transformer.config.in_channels
1157
+ latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
1158
+ conditioning_tensors,
1159
+ strength,
1160
+ frame_index,
1161
+ batch_size=batch_size * num_videos_per_prompt,
1162
+ num_channels_latents=num_channels_latents,
1163
+ height=height,
1164
+ width=width,
1165
+ num_frames=num_frames,
1166
+ sigma=latent_sigma,
1167
+ latents=latents,
1168
+ generator=generator,
1169
+ device=device,
1170
+ dtype=torch.float32,
1171
+ )
1172
+
1173
+ # 4.5. Process reference video (if provided) and concatenate at the beginning
1174
+ reference_latents = None
1175
+ reference_num_latents = 0
1176
+ if reference_video is not None:
1177
+ # Work with the original tensor format [F, C, H, W]
1178
+ ref_frames = reference_video # [F, C, H, W]
1179
+
1180
+ # Resize maintaining aspect ratio (resize all frames)
1181
+ current_height, current_width = ref_frames.shape[2:]
1182
+ aspect_ratio = current_width / current_height
1183
+ target_aspect_ratio = width / height
1184
+
1185
+ if aspect_ratio > target_aspect_ratio:
1186
+ # Width is relatively larger, resize based on height
1187
+ resize_height = height
1188
+ resize_width = int(resize_height * aspect_ratio)
1189
+ else:
1190
+ # Height is relatively larger, resize based on width
1191
+ resize_width = width
1192
+ resize_height = int(resize_width / aspect_ratio)
1193
+
1194
+ ref_frames = resize(ref_frames, [resize_height, resize_width], antialias=True)
1195
+
1196
+ # Center crop to target dimensions
1197
+ ref_frames = center_crop(ref_frames, [height, width])
1198
+
1199
+ # Convert to VAE input format: [1, C, F, H, W] and proper range [-1, 1]
1200
+ reference_tensor = ref_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [1, F, C, H, W] -> [1, C, F, H, W]
1201
+ reference_tensor = reference_tensor * 2.0 - 1.0 # [0, 1] -> [-1, 1]
1202
+
1203
+ # Trim reference video to proper frame count for temporal compression
1204
+ ref_num_frames_input = reference_tensor.size(2)
1205
+ ref_num_frames_output = self.trim_conditioning_sequence(0, ref_num_frames_input, num_frames)
1206
+ reference_tensor = reference_tensor[:, :, :ref_num_frames_output]
1207
+ reference_tensor = reference_tensor.to(device, dtype=vae_dtype)
1208
+
1209
+ # Ensure proper frame count for VAE temporal compression
1210
+ if reference_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
1211
+ # Trim to make it compatible with temporal compression
1212
+ ref_frames_to_keep = (
1213
+ (reference_tensor.size(2) - 1) // self.vae_temporal_compression_ratio
1214
+ ) * self.vae_temporal_compression_ratio + 1
1215
+ reference_tensor = reference_tensor[:, :, :ref_frames_to_keep]
1216
+
1217
+ # Expand reference tensor for batch and num_videos_per_prompt
1218
+ reference_tensor = reference_tensor.repeat(batch_size * num_videos_per_prompt, 1, 1, 1, 1)
1219
+
1220
+ # Encode reference video to latents
1221
+ reference_latents = retrieve_latents(self.vae.encode(reference_tensor), generator=generator)
1222
+ reference_latents = self._normalize_latents(
1223
+ reference_latents, self.vae.latents_mean, self.vae.latents_std
1224
+ ).to(device, dtype=torch.float32)
1225
+
1226
+ # Create "clean" coordinates for reference video (as if no frame conditioning applied)
1227
+ ref_latent_frames = reference_latents.size(2)
1228
+ ref_latent_height = reference_latents.size(3)
1229
+ ref_latent_width = reference_latents.size(4)
1230
+
1231
+ reference_video_coords = self._prepare_video_ids(
1232
+ batch_size * num_videos_per_prompt,
1233
+ ref_latent_frames,
1234
+ ref_latent_height,
1235
+ ref_latent_width,
1236
+ patch_size_t=self.transformer_temporal_patch_size,
1237
+ patch_size=self.transformer_spatial_patch_size,
1238
+ device=device,
1239
+ )
1240
+ reference_video_coords = self._scale_video_ids(
1241
+ reference_video_coords,
1242
+ scale_factor=self.vae_spatial_compression_ratio,
1243
+ scale_factor_t=self.vae_temporal_compression_ratio,
1244
+ frame_index=0, # Reference video starts at frame 0
1245
+ device=device,
1246
+ )
1247
+
1248
+ # Pack reference latents
1249
+ reference_latents = self._pack_latents(
1250
+ reference_latents,
1251
+ self.transformer_spatial_patch_size,
1252
+ self.transformer_temporal_patch_size,
1253
+ )
1254
+ reference_num_latents = reference_latents.size(1)
1255
+
1256
+ # Concatenate reference latents at the beginning: [reference_latents, frame_conditions, target_latents]
1257
+ latents = torch.cat([reference_latents, latents], dim=1)
1258
+
1259
+ # Update video coordinates: [reference_coords, existing_coords]
1260
+ reference_video_coords = reference_video_coords.float()
1261
+ video_coords = torch.cat([reference_video_coords, video_coords], dim=2)
1262
+ video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
1263
+
1264
+ # Update conditioning mask to include reference (frozen = strength 1.0)
1265
+ if conditioning_mask is not None:
1266
+ reference_conditioning_mask = torch.ones(
1267
+ (batch_size * num_videos_per_prompt, reference_num_latents), device=device, dtype=torch.float32
1268
+ )
1269
+ conditioning_mask = torch.cat([reference_conditioning_mask, conditioning_mask], dim=1)
1270
+ else:
1271
+ # If no frame conditioning, still create mask for reference
1272
+ conditioning_mask = torch.ones(
1273
+ (batch_size * num_videos_per_prompt, reference_num_latents), device=device, dtype=torch.float32
1274
+ )
1275
+ # Add zeros for target latents
1276
+ target_conditioning_mask = torch.zeros(
1277
+ (batch_size * num_videos_per_prompt, latents.size(1) - reference_num_latents),
1278
+ device=device,
1279
+ dtype=torch.float32,
1280
+ )
1281
+ conditioning_mask = torch.cat([conditioning_mask, target_conditioning_mask], dim=1)
1282
+
1283
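+ # The first channel of `video_coords` is the temporal coordinate; it is converted from frame units
+ # to time using `frame_rate`. When a reference video is present, this scaling was already applied above.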
+ video_coords = video_coords.float()
1284
+ if reference_video is None:
1285
+ video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
1286
+
1287
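+ # Keep a clean copy of the conditioning latents so fresh noise can be re-injected into them at every
+ # denoising step (controlled by `image_cond_noise_scale` inside the loop).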
+ init_latents = latents.clone() if is_conditioning_image_or_video or reference_video is not None else None
1288
+
1289
+ if self.do_classifier_free_guidance:
1290
+ video_coords = torch.cat([video_coords, video_coords], dim=0)
1291
+
1292
+ # 6. Denoising loop
1293
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1294
+ for i, t in enumerate(timesteps):
1295
+ if self.interrupt:
1296
+ continue
1297
+
1298
+ self._current_timestep = t
1299
+
1300
+ if image_cond_noise_scale > 0 and init_latents is not None:
1301
+ # Add timestep-dependent noise to the hard-conditioning latents
1302
+ # This helps with motion continuity, especially when conditioned on a single frame
1303
+ latents = self.add_noise_to_image_conditioning_latents(
1304
+ t / 1000.0,
1305
+ init_latents,
1306
+ latents,
1307
+ image_cond_noise_scale,
1308
+ conditioning_mask,
1309
+ generator,
1310
+ )
1311
+
1312
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1313
+ if is_conditioning_image_or_video or reference_video is not None:
1314
+ conditioning_mask_model_input = (
1315
+ torch.cat([conditioning_mask, conditioning_mask])
1316
+ if self.do_classifier_free_guidance
1317
+ else conditioning_mask
1318
+ )
1319
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
1320
+
1321
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1322
+ timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
1323
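+ # Conditioning tokens receive a per-token timestep capped at (1 - strength) * 1000, so fully
+ # conditioned tokens (strength 1.0, including reference latents) stay clean while partially
+ # conditioned tokens are only lightly denoised.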
+ if is_conditioning_image_or_video or reference_video is not None:
1324
+ timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
1325
+
1326
+ noise_pred = self.transformer(
1327
+ hidden_states=latent_model_input,
1328
+ encoder_hidden_states=prompt_embeds,
1329
+ timestep=timestep,
1330
+ encoder_attention_mask=prompt_attention_mask,
1331
+ video_coords=video_coords,
1332
+ attention_kwargs=attention_kwargs,
1333
+ return_dict=False,
1334
+ )[0]
1335
+
1336
+ if self.do_classifier_free_guidance:
1337
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1338
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1339
+ timestep, _ = timestep.chunk(2)
1340
+
1341
+ if self.guidance_rescale > 0:
1342
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1343
+ noise_pred = rescale_noise_cfg(
1344
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
1345
+ )
1346
+
1347
+ denoised_latents = self.scheduler.step(
1348
+ -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
1349
+ )[0]
1350
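+ # Only tokens whose conditioning strength allows denoising at the current timestep are updated;
+ # the remaining (hard-conditioned) tokens keep their previous values.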
+ if is_conditioning_image_or_video or reference_video is not None:
1351
+ tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
1352
+ latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
1353
+ else:
1354
+ latents = denoised_latents
1355
+
1356
+ if callback_on_step_end is not None:
1357
+ callback_kwargs = {}
1358
+ for k in callback_on_step_end_tensor_inputs:
1359
+ callback_kwargs[k] = locals()[k]
1360
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1361
+
1362
+ latents = callback_outputs.pop("latents", latents)
1363
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1364
+
1365
+ # call the callback, if provided
1366
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1367
+ progress_bar.update()
1368
+
1369
+ if XLA_AVAILABLE:
1370
+ xm.mark_step()
1371
+
1372
+ # Handle reference video output processing
1373
+ if reference_video is not None and output_reference_comparison:
1374
+ # Split latents: [reference_latents, frame_conditions, target_latents]
1375
+ reference_latents_out = latents[:, :reference_num_latents]
1376
+ remaining_latents = latents[:, reference_num_latents:]
1377
+
1378
+ # Remove frame conditioning from remaining latents if needed
1379
+ if is_conditioning_image_or_video:
1380
+ target_latents_out = remaining_latents[:, extra_conditioning_num_latents:]
1381
+ else:
1382
+ target_latents_out = remaining_latents
1383
+
1384
+ # Process both reference and target latents
1385
+ videos = []
1386
+ for curr_latents in [reference_latents_out, target_latents_out]:
1387
+ if output_type == "latent":
1388
+ curr_video = curr_latents
1389
+ else:
1390
+ curr_latents = self._unpack_latents(
1391
+ curr_latents,
1392
+ latent_num_frames,
1393
+ latent_height,
1394
+ latent_width,
1395
+ self.transformer_spatial_patch_size,
1396
+ self.transformer_temporal_patch_size,
1397
+ )
1398
+ curr_latents = self._denormalize_latents(
1399
+ curr_latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
1400
+ )
1401
+ curr_latents = curr_latents.to(prompt_embeds.dtype)
1402
+
1403
+ if not self.vae.config.timestep_conditioning:
1404
+ timestep = None
1405
+ else:
1406
+ noise = torch.randn(
1407
+ curr_latents.shape, generator=generator, device=device, dtype=curr_latents.dtype
1408
+ )
1409
+ if not isinstance(decode_timestep, list):
1410
+ decode_timestep = [decode_timestep] * batch_size
1411
+ if decode_noise_scale is None:
1412
+ decode_noise_scale = decode_timestep
1413
+ elif not isinstance(decode_noise_scale, list):
1414
+ decode_noise_scale = [decode_noise_scale] * batch_size
1415
+
1416
+ timestep = torch.tensor(decode_timestep, device=device, dtype=curr_latents.dtype)
1417
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=curr_latents.dtype)[
1418
+ :, None, None, None, None
1419
+ ]
1420
+ curr_latents = (1 - decode_noise_scale) * curr_latents + decode_noise_scale * noise
1421
+
1422
+ curr_video = self.vae.decode(curr_latents, timestep, return_dict=False)[0]
1423
+ curr_video = self.video_processor.postprocess_video(curr_video, output_type=output_type)
1424
+ videos.append(curr_video)
1425
+
1426
+ # Concatenate videos side-by-side (along width dimension for visual output)
1427
+ if output_type == "latent":
1428
+ video = torch.cat(videos, dim=0)
1429
+ # For video tensors, shape is [B, C, F, H, W] or list of PIL images
1430
+ elif isinstance(videos[0], list):
1431
+ # Handle PIL images case - concatenate each frame side by side
1432
+ video = []
1433
+ for batch_idx in range(len(videos[0])):
1434
+ combined_video = []
1435
+ for frame_idx in range(len(videos[0][batch_idx])):
1436
+ ref_frame = videos[0][batch_idx][frame_idx]
1437
+ gen_frame = videos[1][batch_idx][frame_idx]
1438
+ # Create side-by-side comparison
1439
+ import PIL.Image
1440
+
1441
+ if isinstance(ref_frame, PIL.Image.Image) and isinstance(gen_frame, PIL.Image.Image):
1442
+ combined_width = ref_frame.width + gen_frame.width
1443
+ combined_height = max(ref_frame.height, gen_frame.height)
1444
+ combined_frame = PIL.Image.new("RGB", (combined_width, combined_height))
1445
+ combined_frame.paste(ref_frame, (0, 0))
1446
+ combined_frame.paste(gen_frame, (ref_frame.width, 0))
1447
+ combined_video.append(combined_frame)
1448
+ else:
1449
+ combined_video.append(gen_frame) # Fallback to generated only
1450
+ video.append(combined_video)
1451
+ else:
1452
+ # Handle tensor case - concatenate along width dimension (dim=4)
1453
+ video = torch.cat(videos, dim=4)
1454
+ else:
1455
+ # Regular processing - just remove conditioning parts and output generated video
1456
+ if reference_video is not None:
1457
+ # Remove reference latents
1458
+ latents = latents[:, reference_num_latents:]
1459
+
1460
+ if is_conditioning_image_or_video:
1461
+ latents = latents[:, extra_conditioning_num_latents:]
1462
+
1463
+ latents = self._unpack_latents(
1464
+ latents,
1465
+ latent_num_frames,
1466
+ latent_height,
1467
+ latent_width,
1468
+ self.transformer_spatial_patch_size,
1469
+ self.transformer_temporal_patch_size,
1470
+ )
1471
+
1472
+ if output_type == "latent":
1473
+ video = latents
1474
+ else:
1475
+ latents = self._denormalize_latents(
1476
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
1477
+ )
1478
+ latents = latents.to(prompt_embeds.dtype)
1479
+
1480
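+ # For a timestep-conditioned VAE, the latents are first blended with fresh noise according to
+ # `decode_noise_scale` and decoded at `decode_timestep`; otherwise no decode timestep is used.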
+ if not self.vae.config.timestep_conditioning:
1481
+ timestep = None
1482
+ else:
1483
+ noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
1484
+ if not isinstance(decode_timestep, list):
1485
+ decode_timestep = [decode_timestep] * batch_size
1486
+ if decode_noise_scale is None:
1487
+ decode_noise_scale = decode_timestep
1488
+ elif not isinstance(decode_noise_scale, list):
1489
+ decode_noise_scale = [decode_noise_scale] * batch_size
1490
+
1491
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
1492
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
1493
+ :, None, None, None, None
1494
+ ]
1495
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
1496
+
1497
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
1498
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
1499
+
1500
+ # Offload all models
1501
+ self.maybe_free_model_hooks()
1502
+
1503
+ if not return_dict:
1504
+ return (video,)
1505
+
1506
+ return LTXPipelineOutput(frames=video)