Spaces · Running on Zero
Commit 57b4b9a · "update" · 1 parent: 08d8dcb
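In short, this update drops the old argparse/FitDiTGenerator setup, builds the FitDiT components once at module import, and marks the try-on entry point with @spaces.GPU so the Space can run on ZeroGPU hardware. A minimal sketch of that pattern, using a placeholder torch module rather than the repo's real classes, might look roughly like this:

import spaces            # Hugging Face ZeroGPU helper used by the new app.py
import torch

# Build the model once at import time, as the diff below now does for the FitDiT pipeline.
# torch.nn.Linear is only a stand-in here; it is not part of the actual app.
model = torch.nn.Linear(4, 4).to(torch.bfloat16)
model.to("cuda")         # on a ZeroGPU Space, the spaces package handles this; the GPU itself is attached per request

@spaces.GPU              # a GPU is allocated only while this function runs
def infer(x: torch.Tensor) -> torch.Tensor:
    with torch.inference_mode():
        return model(x.to("cuda", dtype=torch.bfloat16))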
app.py CHANGED

@@ -24,91 +24,95 @@ access_token = os.getenv("HF_TOKEN")
 fitdit_repo = "BoyuanJiang/FitDiT"
 repo_path = snapshot_download(repo_id=fitdit_repo, use_auth_token=access_token)
 
- … (old lines 27-113 removed; their content is not shown in this view)
+weight_dtype = torch.bfloat16
+device = "cuda"
+transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
+transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
+pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
+pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))
+image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
+image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
+pose_guider.to(device=device, dtype=weight_dtype)
+image_encoder_large.to(device=device)
+image_encoder_bigG.to(device=device)
+pipeline = StableDiffusion3TryOnPipeline.from_pretrained(repo_path, torch_dtype=weight_dtype, \
+                                                         transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, \
+                                                         image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
+pipeline.to(device)
+dwprocessor = DWposeDetector(model_root=repo_path, device=device)
+parsing_model = Parsing(model_root=repo_path, device=device)
+
+
+
 
+def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
+    with torch.inference_mode():
+        vton_img = Image.open(vton_img)
+        vton_img_det = resize_image(vton_img)
+        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:,:,::-1])
+        candidate[candidate<0]=0
+        candidate = candidate[0]
 
+        candidate[:, 0]*=vton_img_det.width
+        candidate[:, 1]*=vton_img_det.height
 
+        pose_image = pose_image[:,:,::-1] #rgb
+        pose_image = Image.fromarray(pose_image)
+        model_parse, _ = parsing_model(vton_img_det)
 
+        mask, mask_gray = get_mask_location(category, model_parse, \
+                                            candidate, model_parse.width, model_parse.height, \
+                                            offset_top, offset_bottom, offset_left, offset_right)
+        mask = mask.resize(vton_img.size)
+        mask_gray = mask_gray.resize(vton_img.size)
+        mask = mask.convert("L")
+        mask_gray = mask_gray.convert("L")
+        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
 
+        im = {}
+        im['background'] = np.array(vton_img.convert("RGBA"))
+        im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
+        im['composite'] = np.array(masked_vton_img.convert("RGBA"))
+
+        return im, pose_image
 
+@spaces.GPU
+def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
+    assert resolution in ["768x1024", "1152x1536", "1536x2048"]
+    new_width, new_height = resolution.split("x")
+    new_width = int(new_width)
+    new_height = int(new_height)
+    with torch.inference_mode():
+        garm_img = Image.open(garm_img)
+        vton_img = Image.open(vton_img)
 
+        model_image_size = vton_img.size
+        garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
+        vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
 
+        mask = pre_mask["layers"][0][:,:,3]
+        mask = Image.fromarray(mask)
+        mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        mask = mask.convert("L")
+        pose_image = Image.fromarray(pose_image)
+        pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        if seed==-1:
+            seed = random.randint(0, 2147483647)
+        res = pipeline(
+            height=new_height,
+            width=new_width,
+            guidance_scale=image_scale,
+            num_inference_steps=n_steps,
+            generator=torch.Generator("cpu").manual_seed(seed),
+            cloth_image=garm_img,
+            model_image=vton_img,
+            mask=mask,
+            pose_image=pose_image,
+            num_images_per_prompt=num_images_per_prompt
+        ).images
+        for idx in range(len(res)):
+            res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
+        return res
 
 
 def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):

@@ -185,8 +189,7 @@ FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers
 If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>.
 """
 
-def create_demo(model_path, device, with_fp16):
-    generator = FitDiTGenerator(model_path, device, with_fp16)
+def create_demo():
     with gr.Blocks(title="FitDiT") as demo:
         gr.Markdown(HEADER)
         with gr.Row():

@@ -294,15 +297,10 @@ def create_demo(model_path, device, with_fp16):
 
         ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
         ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
-        run_mask_button.click(fn=
-        run_button.click(fn=
+        run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
+        run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
     return demo
 
 if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="FitDiT")
-    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
-    parser.add_argument("--fp16", action="store_true", help="Load model with fp16, default is bf16")
-    args = parser.parse_args()
-    demo = create_demo(repo_path, args.device, args.fp16)
+    demo = create_demo()
     demo.launch()
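One detail worth noting in the new code is the hand-off between generate_mask and process: the editable mask is returned as a Gradio image-editor style dict whose layers[0] alpha channel carries the binary mask, and process reads it back with pre_mask["layers"][0][:, :, 3]. A small self-contained sketch of that round trip, with dummy arrays rather than the app's real data, is:

import numpy as np
from PIL import Image

# Dummy stand-ins for what generate_mask computes (values are arbitrary).
h, w = 1024, 768
mask = np.zeros((h, w), dtype=np.uint8)
mask[200:800, 150:600] = 255                          # pretend garment region
mask_gray = np.full((h, w, 3), 127, dtype=np.uint8)   # stand-in for mask_gray.convert("RGB")

# generate_mask packs the mask into the alpha channel of layers[0]:
layer = np.concatenate((mask_gray, mask[:, :, np.newaxis]), axis=2)   # H x W x 4
pre_mask = {"layers": [layer]}

# process recovers it from the same slot before padding and resizing:
recovered = Image.fromarray(pre_mask["layers"][0][:, :, 3]).convert("L")
assert np.array_equal(np.array(recovered), mask)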