Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,313 Bytes
16eb15e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gc
import PIL.Image
import torch
from stable_diffusion_xl_partedit import PartEditPipeline, DotDictExtra, Binarization, PaddingStrategy, EmptyControl
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
# Token-embedding files, addressed by position: `download_part(index)` looks
# entries up by list index, so the order below must not change.
available_pts = [
    # --- custom-trained tokens ---
    "pt/torso_custom.pt",  # human torso only
    "pt/chair_custom.pt",  # only the seat of the chair
    "pt/carhood_custom.pt",
    # --- "partimage" tokens ---
    "pt/partimage_biped_head.pt",  # essentially monkeys
    "pt/partimage_carbody.pt",  # everything except the wheels
    "pt/partimage_human_hair.pt",
    "pt/partimage_human_head.pt",  # essentially faces
    "pt/partimage_human_torso.pt",  # prefer the custom torso token over this one
    "pt/partimage_quadruped_head.pt",  # essentially four-legged animals
]
def download_part(index):
    """Fetch one token file from the `Aleksandar/PartEdit-extra` dataset repo.

    Args:
        index: Position of the wanted file inside `available_pts`.

    Returns:
        Whatever `hf_hub_download` returns for that file (per the
        huggingface_hub docs, the local path of the downloaded/cached file).

    Raises:
        IndexError: If `index` is outside the bounds of `available_pts`.
    """
    part_file = available_pts[index]
    return hf_hub_download(
        filename=part_file,
        repo_type="dataset",
        repo_id="Aleksandar/PartEdit-extra",
    )
# Maps the part names accepted by the editor to the downloaded token file for
# that part. The values are produced by calling `download_part` right here, so
# every file is fetched when this module is imported.
# Each index refers to a position in `available_pts`; the trailing comment
# names the file that index selects.
PART_TOKENS = {
    "human_head": download_part(6),  # pt/partimage_human_head.pt
    "human_hair": download_part(5),  # pt/partimage_human_hair.pt
    "human_torso_custom": download_part(0),  # pt/torso_custom.pt (custom one)
    "chair_custom": download_part(1),  # pt/chair_custom.pt
    "carhood_custom": download_part(2),  # pt/carhood_custom.pt
    "carbody": download_part(4),  # pt/partimage_carbody.pt
    # Fixed: these two indices were swapped, so "biped_head" loaded the
    # quadruped token and "quadruped_head" loaded the biped token.
    "biped_head": download_part(3),  # pt/partimage_biped_head.pt
    "quadruped_head": download_part(8),  # pt/partimage_quadruped_head.pt
    "human_torso": download_part(7),  # pt/partimage_human_torso.pt (based on partimage)
}
class PartEditSDXLModel:
    """Wrapper around the two pipelines returned by `PartEditPipeline.default_pipeline`.

    `generate` runs `sd_pipe` on a single prompt; `edit` runs `partedit_pipe`
    on an (original, edited) prompt pair together with a part token from
    `PART_TOKENS`. Both methods refuse to run without CUDA.
    """

    # Not referenced inside this class; presumably used by the UI layer to cap
    # the number of steps -- TODO confirm against the caller.
    MAX_NUM_INFERENCE_STEPS = 50

    def __init__(self):
        """Build the pipelines on the current CUDA device, if there is one.

        Without CUDA every pipeline attribute is set to ``None`` so the
        instance always has the same attributes; `generate` and `edit` then
        raise ``RuntimeError`` when called.
        """
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
            self.sd_pipe, self.partedit_pipe = PartEditPipeline.default_pipeline(self.device)
        else:
            self.device = torch.device("cpu")
            self.sd_pipe = None
            self.partedit_pipe = None
            # Kept only for backward compatibility: the CPU branch has always
            # set this attribute, although no method in this class reads it.
            self.pipe = None

    @staticmethod
    def _require_cuda() -> None:
        """Raise ``RuntimeError`` when no CUDA device is available."""
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU!")

    @staticmethod
    def _free_memory() -> None:
        """Run the garbage collector and empty the CUDA caching allocator."""
        gc.collect()
        torch.cuda.empty_cache()

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0,
    ) -> PIL.Image.Image:
        """Generate one image from `prompt` with `sd_pipe`.

        Args:
            prompt: Text prompt passed to the pipeline.
            negative_prompt: Accepted for interface compatibility but NOT
                forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            seed: Seed for the `torch.Generator` handed to the pipeline.
            eta: Forwarded to the pipeline.

        Returns:
            The first entry of the pipeline output's ``images``.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        self._require_cuda()
        try:
            return self.sd_pipe(
                prompt=prompt,
                # negative_prompt is deliberately left out of the call (see docstring).
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                eta=eta,
                generator=torch.Generator().manual_seed(seed),
            ).images[0]
        finally:
            # Release memory even when the pipeline raises (e.g. on OOM).
            self._free_memory()

    def edit(
        self,
        prompt: str,
        subject: str,
        part: str,
        edit: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0,
        t_e: int = 50,
    ) -> tuple:
        """Run `partedit_pipe` on `prompt` and a copy with `subject` replaced by `edit`.

        Args:
            prompt: Original text prompt; must contain `subject`.
            subject: Substring of `prompt` that is replaced by `edit` to form
                the second prompt.
            part: Key into `PART_TOKENS` selecting the token file to use.
            edit: Replacement text for `subject`; also used as a key in the
                ``n_cross_replace`` schedule with value 0.4.
            negative_prompt: Accepted for interface compatibility but NOT
                forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            seed: Seed for the `torch.Generator` handed to the pipeline.
            eta: Forwarded to the pipeline.
            t_e: Passed to the pipeline as ``edit_steps`` in the extra kwargs.

        Returns:
            A ``(images, mask)`` tuple. ``images`` is the first two pipeline
            output images in reversed order (presumably edited image first,
            then the original -- TODO confirm against the pipeline); ``mask``
            is the result of ``partedit_pipe.visualize_map_across_time()``.

        Raises:
            RuntimeError: If CUDA is not available.
            ValueError: If `part` is not in `PART_TOKENS`, if `subject` is
                empty, or if `subject` does not occur in `prompt`.
        """
        # Sanity checks
        self._require_cuda()
        if part not in PART_TOKENS:
            raise ValueError(f"Part `{part}` is not supported!")
        token_path = PART_TOKENS[part]
        if not subject:
            # "" is "in" every string, and str.replace("", x) would insert
            # `edit` between every character of the prompt.
            raise ValueError("The subject must not be empty!")
        if subject not in prompt:
            raise ValueError(f"The subject `{subject}` does not exist in the original prompt!")

        prompts = [prompt, prompt.replace(subject, edit)]

        # PartEdit parameters
        cross_attention_kwargs = {
            "edit_type": "replace",
            "n_self_replace": 0.0,
            "n_cross_replace": {"default_": 1.0, edit: 0.4},
        }
        extra_params = DotDictExtra()
        extra_params.update({"omega": 1.5, "edit_steps": t_e})

        try:
            out = self.partedit_pipe(
                prompt=prompts,
                # negative_prompt is deliberately left out of the call (see docstring).
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                eta=eta,
                generator=torch.Generator().manual_seed(seed),
                cross_attention_kwargs=cross_attention_kwargs,
                extra_kwargs=extra_params,
                embedding_opt=token_path,
            ).images[:2][::-1]
            mask = self.partedit_pipe.visualize_map_across_time()
        finally:
            # Release memory even when the pipeline raises (e.g. on OOM).
            self._free_memory()
        return out, mask
|