Spaces: Runtime error
import gradio as gr
import torch
import numpy as np
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
import tempfile
# --- Load Model ---
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-Diffusers"
# Image encoder and VAE stay in full precision (as in the upstream Diffusers example);
# the 14B transformer runs in half precision to save memory.
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
)

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Keep components on the CPU and move each to the GPU only while it is needed,
    # so the 14B model does not have to fit in VRAM all at once.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
# --- Helper Functions ---
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    # Pick the largest width/height that preserve the aspect ratio, stay within
    # max_area, and are multiples of the VAE/transformer spatial granularity.
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width
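# Example (assuming the Wan VAE's 8x spatial compression and a spatial patch size of 2,
# so mod_value = 16): with the default max_area of 720 * 1280, a 1080x1920 portrait
# input is resized to 720x1280 by the helper above.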
def center_crop_resize(image, height, width):
    # Scale the image so it covers the target size, then center-crop to exactly (height, width)
    resize_ratio = max(width / image.width, height / image.height)
    image = image.resize((round(image.width * resize_ratio), round(image.height * resize_ratio)))
    image = TF.center_crop(image, [height, width])
    return image, height, width
# --- Gradio Inference Function ---
def infer(first_image, last_image, prompt, guidance=5.5, frames=25):
    # Ensure PIL images (gr.Image with type="pil" already returns PIL, but stay defensive)
    if not isinstance(first_image, Image.Image):
        first_image = Image.fromarray(first_image)
    if not isinstance(last_image, Image.Image):
        last_image = Image.fromarray(last_image)

    # Resize the first frame onto the model's resolution grid; crop the last frame to match
    first_image, height, width = aspect_ratio_resize(first_image, pipe)
    if last_image.size != first_image.size:
        last_image, _, _ = center_crop_resize(last_image, height, width)
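    # The video is generated at (height, width), taken from the resized first frame,
    # so both conditioning frames now share that resolution.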
    # Run pipeline: first frame via `image`, last frame via `last_image`
    output = pipe(
        image=first_image,
        last_image=last_image,
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance,
        num_frames=int(frames),  # slider values may arrive as floats
    ).frames[0]

    # Export to video
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=16)
    return tmp.name
# --- Gradio Interface ---
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="pil", label="Start Frame"),
        gr.Image(type="pil", label="End Frame"),
        gr.Textbox(placeholder="Prompt (optional)", label="Prompt"),
        gr.Slider(3, 12, value=5.5, step=0.1, label="Guidance Scale"),
        gr.Slider(8, 48, value=25, step=1, label="Num Frames"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="WAN Two-Frame Video Interpolation",
    description="Upload two images and (optionally) a prompt to create a smooth video transition.",
)
if __name__ == "__main__":
    demo.launch(show_api=True, delete_cache=(60, 60))
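# --- Optional local smoke test (sketch, not used by the Space) ---
# A minimal way to exercise infer() without the Gradio UI; the image paths below
# are hypothetical placeholders.
#
#   start = Image.open("frame_start.png")
#   end = Image.open("frame_end.png")
#   print(infer(start, end, "a smooth transition between the two frames"))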