Spaces: Running on Zero
| import gc | |
| import PIL.Image | |
| import torch | |
| from stable_diffusion_xl_partedit import PartEditPipeline, DotDictExtra, Binarization, PaddingStrategy, EmptyControl | |
| from diffusers import AutoencoderKL | |
| from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | |
| from transformers import CLIPImageProcessor | |
| from huggingface_hub import hf_hub_download | |
# Token-embedding files shipped in the PartEdit-extra dataset repo.
# NOTE: order is load-bearing — `download_part` and `PART_TOKENS` refer to
# entries by positional index; do not reorder.
available_pts = [
    "pt/torso_custom.pt",  # 0: this is human torso only
    "pt/chair_custom.pt",  # 1: this is seat of the chair only
    "pt/carhood_custom.pt",  # 2
    "pt/partimage_biped_head.pt",  # 3: this is essentially monkeys
    "pt/partimage_carbody.pt",  # 4: this is everything except the wheels
    "pt/partimage_human_hair.pt",  # 5
    "pt/partimage_human_head.pt",  # 6: this is essentially faces
    "pt/partimage_human_torso.pt",  # 7: use custom on in favour of this one
    "pt/partimage_quadruped_head.pt",  # 8: this is essentially animals on 4 legs
]
def download_part(index: int) -> str:
    """Download one PartEdit token embedding and return its local cached path.

    `index` selects an entry from `available_pts`; the file is fetched (or
    served from the local HF cache) out of the `Aleksandar/PartEdit-extra`
    dataset repository.
    """
    target_file = available_pts[index]
    local_path = hf_hub_download(
        repo_id="Aleksandar/PartEdit-extra",
        repo_type="dataset",
        filename=target_file,
    )
    return local_path
# Friendly part name -> local path of the downloaded token embedding.
# Index comments mirror the filenames in `available_pts`.
PART_TOKENS = {
    "human_head": download_part(6),  # partimage_human_head.pt
    "human_hair": download_part(5),  # partimage_human_hair.pt
    "human_torso_custom": download_part(0),  # custom one (torso_custom.pt)
    "chair_custom": download_part(1),  # chair_custom.pt
    "carhood_custom": download_part(2),  # carhood_custom.pt
    "carbody": download_part(4),  # partimage_carbody.pt
    # FIX: biped/quadruped indices were swapped — index 3 is
    # partimage_biped_head.pt and index 8 is partimage_quadruped_head.pt.
    "biped_head": download_part(3),
    "quadruped_head": download_part(8),
    "human_torso": download_part(7),  # based on partimage
}
class PartEditSDXLModel:
    """Thin wrapper around the PartEdit SDXL pipelines.

    Holds two pipelines built by ``PartEditPipeline.default_pipeline``:
    a plain SDXL pipeline for generation (``sd_pipe``) and the PartEdit
    pipeline for part-level editing (``partedit_pipe``). CUDA is required;
    on CPU-only hosts the pipelines are left as ``None`` and both public
    methods raise ``RuntimeError``.
    """

    # Upper bound intended for the demo UI; not enforced inside this class.
    MAX_NUM_INFERENCE_STEPS = 50

    def __init__(self):
        if torch.cuda.is_available():
            # The outer guard already established CUDA availability; no need
            # to re-check inside the f-string (the original did so redundantly).
            self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
            self.sd_pipe, self.partedit_pipe = PartEditPipeline.default_pipeline(self.device)
        else:
            # FIX: the CPU branch used to set only `self.pipe`, while every
            # method reads `self.sd_pipe` / `self.partedit_pipe`. Define the
            # attributes the methods actually use; keep `self.pipe` for
            # backward compatibility with any external reader.
            self.pipe = None
            self.sd_pipe = None
            self.partedit_pipe = None

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0,
    ) -> PIL.Image.Image:
        """Generate a single image from `prompt` with the plain SDXL pipeline.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU!")
        out = self.sd_pipe(
            prompt=prompt,
            # NOTE(review): negative_prompt is accepted but intentionally not
            # forwarded to the pipeline (kept disabled as in the original).
            # negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        # Free transient GPU memory between Space requests.
        gc.collect()
        torch.cuda.empty_cache()
        return out

    def edit(
        self,
        prompt: str,
        subject: str,
        part: str,
        edit: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = 0,
        eta: float = 0,  # FIX: was annotated `int`; `generate` uses `float`
        t_e: int = 50,
    ) -> tuple:  # FIX: was annotated as a single image; returns (images, mask)
        """Edit `part` of `subject` in `prompt`, replacing it with `edit`.

        Runs the PartEdit pipeline on the pair [original prompt, edited
        prompt] using the pre-trained part token embedding selected by
        `part`, and returns ``(images, mask)`` where ``images`` is the
        [edited, original] pair and ``mask`` visualizes the part map
        aggregated across timesteps.

        Raises:
            RuntimeError: if CUDA is not available.
            ValueError: if `part` is unsupported or `subject` is absent
                from `prompt`.
        """
        # Sanity checks first so we fail fast before any GPU work.
        if not torch.cuda.is_available():
            raise RuntimeError("This demo does not work on CPU!")
        if part in PART_TOKENS:
            token_path = PART_TOKENS[part]
        else:
            raise ValueError(f"Part `{part}` is not supported!")
        if subject not in prompt:
            raise ValueError(f"The subject `{subject}` does not exist in the original prompt!")

        # Source prompt plus the edited prompt with the subject swapped out.
        prompts = [
            prompt,
            prompt.replace(subject, edit),
        ]

        # PartEdit parameters: prompt-to-prompt style cross-attention
        # replacement; the edit token keeps 40% of its own cross attention.
        cross_attention_kwargs = {
            "edit_type": "replace",
            "n_self_replace": 0.0,
            "n_cross_replace": {"default_": 1.0, edit: 0.4},
        }
        extra_params = DotDictExtra()
        # omega scales the edit strength; edit_steps bounds how many
        # denoising steps apply the part edit.
        extra_params.update({"omega": 1.5, "edit_steps": t_e})

        out = self.partedit_pipe(
            prompt=prompts,
            # NOTE(review): negative_prompt intentionally not forwarded,
            # matching `generate`.
            # negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=torch.Generator().manual_seed(seed),
            cross_attention_kwargs=cross_attention_kwargs,
            extra_kwargs=extra_params,
            embedding_opt=token_path,
        ).images[:2][::-1]  # reverse so the edited image comes first
        mask = self.partedit_pipe.visualize_map_across_time()

        # Free transient GPU memory between Space requests.
        gc.collect()
        torch.cuda.empty_cache()
        return out, mask