from typing import List

import math
import random

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from transparent_background import Remover

from dataset.opencv_transforms.functional import to_tensor, center_crop
from vtdm.model import create_model
from vtdm.util import tensor2vid

remover = Remover(jit=False)


def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
    # Convert a PIL image to an OpenCV-style NumPy array (RGB -> BGR).
    cv_image = np.array(pil_image)
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
    return cv_image


def prepare_white_image(input_image: Image.Image) -> Image.Image:
    # Remove the background.
    output = remover.process(input_image, type='rgba')

    # Pad to a square RGBA canvas, centering the subject on a transparent background.
    width, height = output.size
    max_side = max(width, height)
    white_image = Image.new('RGBA', (max_side, max_side), (0, 0, 0, 0))
    x_offset = (max_side - width) // 2
    y_offset = (max_side - height) // 2
    white_image.paste(output, (x_offset, y_offset))
    return white_image
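
# Illustrative usage of prepare_white_image (a sketch, not part of the original
# Space; "examples/chair.png" is a hypothetical path):
#
#   raw = Image.open("examples/chair.png").convert("RGB")
#   square_rgba = prepare_white_image(raw)   # background removed, padded to a square
#   square_rgba.save("preprocessed.png")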


class MultiViewGenerator:
    def __init__(self, checkpoint_path, config_path="inference.yaml"):
        # Build the video diffusion model from the config, load the checkpoint,
        # and move it to the GPU in half precision.
        self.models = {}
        denoising_model = create_model(config_path).cpu()
        denoising_model.init_from_ckpt(checkpoint_path)
        denoising_model = denoising_model.cuda().half()
        self.models["denoising_model"] = denoising_model

    def denoising(self, frames, args):
        # Run one sampling pass of the video diffusion model over a (C, T, H, W)
        # clip and return the decoded frames.
        with torch.no_grad():
            C, T, H, W = frames.shape
            batch = {"video": frames.unsqueeze(0)}
            batch["elevation"] = (
                torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
            )
            batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
            batch["motion_bucket_id"] = (
                torch.Tensor([127]).to(torch.int64).to(frames.device)
            )
            batch = self.models["denoising_model"].add_custom_cond(batch, infer=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                c, uc = self.models[
                    "denoising_model"
                ].conditioner.get_unconditional_conditioning(
                    batch,
                    force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
                )

            additional_model_inputs = {
                "image_only_indicator": torch.zeros(2, T).to(
                    self.models["denoising_model"].device
                ),
                "num_video_frames": batch["num_video_frames"],
            }

            def denoiser(input, sigma, c):
                return self.models["denoising_model"].denoiser(
                    self.models["denoising_model"].model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs,
                )

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                randn = torch.randn(
                    [T, 4, H // 8, W // 8], device=self.models["denoising_model"].device
                )
                samples = self.models["denoising_model"].sampler(
                    denoiser, randn, cond=c, uc=uc
                )
                samples = self.models["denoising_model"].decode_first_stage(samples.half())

            samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)
            return tensor2vid(samples)

    def video_pipeline(self, frames, args) -> List[Image.Image]:
        num_iter = args["num_iter"]
        out_list = []
        for _ in range(num_iter):
            with torch.no_grad():
                results = self.denoising(frames, args)
            # On later passes, skip the first frame, which repeats the carried-over
            # conditioning view.
            if len(out_list) == 0:
                out_list = out_list + results
            else:
                out_list = out_list + results[1:]
            # Carry the last generated view into the first frame slot for the next pass.
            img = out_list[-1]
            img = to_tensor(img)
            img = (img - 0.5) * 2.0
            frames[:, 0] = img

        # Remove the background from every generated view.
        result = []
        for i, frame in enumerate(out_list):
            input_image = Image.fromarray(frame)
            output_image = remover.process(input_image, type='rgba')
            result.append(output_image)
        return result

    def process(self, white_image: Image.Image, args) -> List[Image.Image]:
        img = pil_to_cv2(white_image)
        frame_list = [img] * args["clip_size"]

        # Resize so both sides reach the target resolution, then center-crop.
        h, w = frame_list[0].shape[0:2]
        rate = max(
            args["input_resolution"][0] * 1.0 / h, args["input_resolution"][1] * 1.0 / w
        )
        frame_list = [
            cv2.resize(f, (math.ceil(w * rate), math.ceil(h * rate))) for f in frame_list
        ]
        frame_list = [
            center_crop(f, [args["input_resolution"][0], args["input_resolution"][1]])
            for f in frame_list
        ]
        frame_list = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frame_list]

        # Normalize to [-1, 1] and stack into a (C, T, H, W) clip on the GPU.
        frame_list = [to_tensor(f) for f in frame_list]
        frame_list = [(f - 0.5) * 2.0 for f in frame_list]
        frames = torch.stack(frame_list, 1)
        frames = frames.cuda()

        self.models["denoising_model"].num_samples = args["clip_size"]
        self.models["denoising_model"].image_size = args["input_resolution"]
        return self.video_pipeline(frames, args)

    def infer(self, white_image: Image.Image) -> List[Image.Image]:
        seed = random.randint(0, 65535)
        seed_everything(seed)
        params = {
            "clip_size": 25,
            "input_resolution": [512, 512],
            "num_iter": 1,
            "aes": 6.0,
            "mv": [0.0, 0.0, 0.0, 10.0],
            "elevation": 0,
        }
        return self.process(white_image, params)
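

if __name__ == "__main__":
    # Minimal end-to-end sketch (an assumption, not part of the original file):
    # the image and checkpoint paths are hypothetical command-line placeholders.
    import sys

    image_path, checkpoint_path = sys.argv[1], sys.argv[2]

    generator = MultiViewGenerator(checkpoint_path, config_path="inference.yaml")
    white_image = prepare_white_image(Image.open(image_path).convert("RGB"))
    views = generator.infer(white_image)

    # Save each generated RGBA view to disk.
    for idx, view in enumerate(views):
        view.save(f"view_{idx:02d}.png")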