GeradeHouse committed
Commit 47b7da6 · verified · 1 Parent(s): f40229f

Update app.py

Files changed (1)
  1. app.py +22 -24
app.py CHANGED
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
-Author: GeradeHouse
+Uses Accelerate’s automatic device mapping for optimal CPU/GPU placement.
+Author: <your-handle>
 """
 
 import numpy as np
@@ -16,46 +17,44 @@ import torchvision.transforms.functional as TF
 # ---------------------------------------------------------------------
 # CONFIG ----------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers" # or switch to 1.3B
-DTYPE = torch.float16 # or bfloat16
+DTYPE = torch.float16 # or torch.bfloat16
 MAX_AREA = 1280 * 720 # ≤720p
-DEFAULT_FRAMES = 81 # ~5s @16 fps
+DEFAULT_FRAMES = 81 # ~5s @16fps
 # ----------------------------------------------------------------------
 
 def load_pipeline():
-    """Lazy-load & configure the pipeline once per process."""
-    # 1) load the CLIP image encoder (full-precision)
+    """Load & auto-map the pipeline across CPU/GPU with low CPU memory usage."""
+    # 1) load vision encoder (full precision)
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) load the VAE (half-precision)
+    # 2) load VAE (half precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) load the video pipeline
+    # 3) load the video pipeline with Accelerate helpers
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
         image_encoder=image_encoder,
         torch_dtype=DTYPE,
+        low_cpu_mem_usage=True, # lazy-load weights into CPU RAM
+        device_map="auto",      # auto-split across CPU/GPU
     )
 
-    # 4) override the processor with the fast Rust implementation
+    # 4) use the fast Rust-backed processor
    pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
 
-    # 5) memory helpers (offload UNet to CPU as needed)
-    # pipe.enable_model_cpu_offload()
-    # (Removed pipe.vae.enable_slicing() — not supported on AutoencoderKLWan)
-
-    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+    return pipe
 
 PIPE = load_pipeline()
 
 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    """Resize while keeping aspect & respecting patch multiples."""
+    """Resize while keeping aspect and patch-size multiples."""
     ar = img.height / img.width
     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
@@ -63,11 +62,10 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     return img.resize((w, h), Image.LANCZOS), h, w
 
 def center_crop_resize(img: Image.Image, h, w):
-    """Centercrop & resize to H×W."""
+    """Center-crop & resize to target H×W."""
     ratio = max(w / img.width, h / img.height)
     img = img.resize(
-        (round(img.width * ratio), round(img.height * ratio)),
-        Image.LANCZOS
+        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
     )
     return TF.center_crop(img, [h, w])
 
@@ -76,7 +74,7 @@ def center_crop_resize(img: Image.Image, h, w):
 def generate(first_frame, last_frame, prompt, negative_prompt, steps,
              guidance, num_frames, seed, fps):
 
-    # seed handling
+    # handle seed
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)
@@ -86,7 +84,7 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    # run the pipeline
+    # inference
     output = PIPE(
         image=first_frame,
         last_image=last_frame,
@@ -99,9 +97,9 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
         guidance_scale=guidance,
         generator=gen,
     )
-    frames = output.frames[0] # list of PIL Image
+    frames = output.frames[0] # list[PIL.Image]
 
-    # export to MP4
+    # export to mp4
     video_path = export_to_video(frames, fps=fps)
     return video_path, seed
 
@@ -112,10 +110,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
-        last_img = gr.Image(label="Last frame", type="pil")
+        last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
 
     with gr.Accordion("Advanced parameters", open=False):
         steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
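Note: as a quick sanity check on the new loading path (device_map="auto" plus low_cpu_mem_usage=True instead of an explicit .to("cuda")), here is a minimal sketch of how one might confirm where Accelerate placed each pipeline component. The report_placement helper below is hypothetical and not part of app.py, and whether this diffusers pipeline accepts device_map="auto" depends on the installed diffusers/accelerate versions.

import torch

def report_placement(pipe):
    """Print the device and dtype of every torch.nn.Module in the pipeline's components dict."""
    for name, component in pipe.components.items():
        if isinstance(component, torch.nn.Module):
            param = next(component.parameters(), None)  # first parameter, if the module has any
            if param is not None:
                print(f"{name}: device={param.device}, dtype={param.dtype}")

# e.g., after the demo has built PIPE:
# report_placement(PIPE)

If everything fits on one GPU, each component should report cuda:0; on a smaller card, automatic mapping may leave some modules on cpu, which is expected rather than an error.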