GeradeHouse committed on
Commit 7725ce2 · verified · 1 Parent(s): 5db3e50

Update app.py

Files changed (1): app.py (+46 −45)
app.py CHANGED
@@ -1,10 +1,16 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
-Loads once, uses balanced device placement, streams high-level progress,
-and auto-offers the .mp4 for download.
+• Single global load (no repeated downloads)
+• Balanced device_map to avoid OOM on 24 GB A10
+• Fast CLIP processor via use_fast=True
+• High-level streaming progress
+• Auto-download via gr.File
 """
 import os
+# persist Hugging Face cache so safetensors only download once
+os.environ["HF_HOME"] = "/mnt/data/huggingface"
+
 import numpy as np
 import torch
 import gradio as gr
@@ -22,29 +28,26 @@ DTYPE = torch.float16
 MAX_AREA = 1280 * 720
 DEFAULT_FRAMES = 81
 
-# keep Hugging Face cache on disk so we don't re-download
-os.environ["HF_HOME"] = "/mnt/data/huggingface"
-
 # -----------------------------------------------------------------------------
-# PIPELINE LOADED ONCE
+# LOAD PIPELINE ONCE
 # -----------------------------------------------------------------------------
 def load_pipeline():
-    # 1) image encoder in full precision
+    # 1) CLIP image encoder (fp32)
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) VAE in half precision (no slicing API here)
+    # 2) VAE (fp16)
    vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) load full pipeline balanced across GPU/CPU, with the fast processor
+    # 3) Balanced device placement + fast processor
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         image_encoder=image_encoder,
         vae=vae,
         torch_dtype=DTYPE,
-        device_map="balanced",   # spreads the model to fit your 24 GB
-        use_fast=True,           # get the fast CLIPImageProcessor internally
+        device_map="balanced",   # spread weights CPU↔GPU
+        use_fast=True,           # internal fast CLIPImageProcessor
     )
     return pipe
 
@@ -52,58 +55,58 @@ PIPE = load_pipeline()
 
 
 # -----------------------------------------------------------------------------
-# UTILS
+# HELPERS
 # -----------------------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    ar = img.height / img.width
-    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
-    h = int(np.sqrt(max_area * ar)) // mod * mod
-    w = int(np.sqrt(max_area / ar)) // mod * mod
+    ar = img.height / img.width
+    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
+    h = int(np.sqrt(max_area * ar)) // mod * mod
+    w = int(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w
 
 def center_crop_resize(img: Image.Image, h, w):
     ratio = max(w / img.width, h / img.height)
-    img = img.resize(
+    img2 = img.resize(
         (round(img.width * ratio), round(img.height * ratio)),
         Image.LANCZOS
     )
-    return TF.center_crop(img, [h, w])
+    return TF.center_crop(img2, [h, w])
 
 
 # -----------------------------------------------------------------------------
-# GENERATE WITH STREAMING PROGRESS
+# GENERATION + STREAMING
 # -----------------------------------------------------------------------------
 def generate(
-    first_frame: Image.Image,
-    last_frame: Image.Image,
-    prompt: str,
-    negative: str,
-    steps: int,
-    guidance: float,
-    num_frames: int,
-    seed: int,
-    fps: int,
-    progress= gr.Progress(),
+    first_frame: Image.Image,
+    last_frame: Image.Image,
+    prompt: str,
+    negative: str,
+    steps: int,
+    guidance: float,
+    num_frames: int,
+    seed: int,
+    fps: int,
+    progress= gr.Progress(),
 ):
-    # seed
+    # choose seed
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)
 
     # 0–15%: resize
     progress(0.0, desc="Resizing first frame…")
-    first_resized, h, w = aspect_resize(first_frame)
-    if last_frame.size != first_resized.size:
+    f_resized, h, w = aspect_resize(first_frame)
+    if last_frame.size != f_resized.size:
         progress(0.15, desc="Resizing last frame…")
-        last_resized = center_crop_resize(last_frame, h, w)
+        l_resized = center_crop_resize(last_frame, h, w)
     else:
-        last_resized = first_resized
+        l_resized = f_resized
 
-    # 15–25%: warm up
-    progress(0.25, desc="Initializing pipeline…")
+    # 15–25%: spin up pipeline
+    progress(0.25, desc="Launching inference…")
     out = PIPE(
-        image=first_resized,
-        last_image=last_resized,
+        image=f_resized,
+        last_image=l_resized,
         prompt=prompt,
         negative_prompt=negative or None,
         height=h,
@@ -114,12 +117,11 @@ def generate(
         generator=gen,
     )
 
-    # 25–90%: inference happens inside the pipeline (console shows bars)
-    progress(0.90, desc="Exporting video…")
+    # 90–100%: export
+    progress(0.90, desc="Building video file…")
     video_path = export_to_video(out.frames[0], fps=fps)
-
-    # done
     progress(1.0, desc="Done!")
+
    return video_path, seed
 
 
@@ -134,14 +136,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     last_img = gr.Image(label="Last frame", type="pil")
 
     prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="blurry, low-res")
+    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")
 
     with gr.Accordion("Advanced parameters", open=False):
         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
         num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
         fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
+        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")
 
     run_btn = gr.Button("Generate")
     download = gr.File(label="Download .mp4", interactive=False)
@@ -154,5 +156,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         outputs=[ download, seed_used ],
     )
 
-# serialize tasks with a mini progress badge
 demo.queue().launch(server_name="0.0.0.0", server_port=7860)
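Note on the cache move in the first hunk: huggingface_hub resolves HF_HOME into its cache-path constants when it is first imported, so the variable only takes effect if it is exported before the diffusers/transformers imports, which is why the assignment now sits directly under "import os". A minimal sketch of that ordering (the path is the one used in this commit; only the ordering is the point):

import os
os.environ["HF_HOME"] = "/mnt/data/huggingface"  # must run before any Hugging Face import

from huggingface_hub import constants            # cache paths are resolved here, at import time
assert constants.HF_HOME.startswith("/mnt/data/huggingface")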
 
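For reference, aspect_resize keeps the input aspect ratio, caps the pixel area at MAX_AREA, and floors both sides to a multiple of mod so the VAE/transformer latent grid divides evenly; center_crop_resize then reuses the same (h, w) for the last frame when its size differs. The standalone sketch below reproduces that arithmetic; the factor 16 (vae_scale_factor_spatial 8 × patch_size[1] 2) is an assumed value for illustration, not read from the model config:

import numpy as np

MAX_AREA = 1280 * 720
MOD = 8 * 2  # assumption: vae_scale_factor_spatial (8) * patch_size[1] (2)

def snapped_dims(height, width, max_area=MAX_AREA, mod=MOD):
    # same arithmetic as aspect_resize: preserve the ratio, stay under max_area,
    # floor both sides down to a multiple of mod
    ar = height / width
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return h, w

print(snapped_dims(1000, 800))  # -> (1072, 848); 1072 * 848 = 909056 <= 921600, both divisible by 16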