import gradio as gr
import torch
import numpy as np
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
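
# Load the 14B FLF2V pipeline once at startup; reloading it inside the
# request handler would repeat the full model load on every call.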
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,                      # keep the VAE in float32 for stable decoding
    image_encoder=image_encoder,  # keep the CLIP vision encoder in float32 as well
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
pipe.to("cuda")


def generate_video(first_frame_url, last_frame_url, prompt):
    first_frame = load_image(first_frame_url)
    last_frame = load_image(last_frame_url)

    def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
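        # Scale to roughly max_area pixels while preserving aspect ratio,
        # snapping both sides to the model's spatial granularity
        # (VAE downscale factor x transformer patch size).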
        aspect_ratio = image.height / image.width
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        image = image.resize((width, height))
        return image, height, width

    def center_crop_resize(image, height, width):
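        # Scale the image up just enough to cover the target size, then
        # crop the centre so it matches the first frame exactly.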
        resize_ratio = max(width / image.width, height / image.height)
        image = image.resize(
            (round(image.width * resize_ratio), round(image.height * resize_ratio))
        )
        image = TF.center_crop(image, [height, width])  # center_crop expects [h, w]
        return image, height, width

    first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
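    # Bring the last frame to the same resolution as the first frame.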
    if last_frame.size != first_frame.size:
        last_frame, _, _ = center_crop_resize(last_frame, height, width)

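    # Generate the interpolated clip; .frames[0] is the frame list for the
    # first (and only) video in the batch.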
    output = pipe(
        image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
    ).frames[0]

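    # Encode the frames to MP4; 16 fps matches the rate used in the
    # Wan2.1 reference examples.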
| video_path = "wan_output.mp4" | |
| export_to_video(output, video_path, fps=16) | |
| return video_path | |
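
# Simple Gradio front end: two frame URLs and a text prompt in, a video out.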
iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(label="First Frame URL"),
        gr.Textbox(label="Last Frame URL"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="Wan2.1 FLF2V Video Generator",
)
iface.launch()