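"""Gradio Space (ZeroGPU) that compares how different VAE models reconstruct an input image."""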
import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL, AutoencoderDC, AutoModel
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
import time
from huggingface_hub import login

# Get the token from the environment and log in only when one is provided
hf_token = os.getenv("access_token")
if hf_token:
    login(token=hf_token)
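
# Padding to a square before resizing preserves the original aspect ratio.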
class PadToSquare:
    """Custom transform to pad an image to square dimensions"""
    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")
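
# The VAEs below expect inputs in [-1, 1]; input_transform applies that mapping
# and output_transform maps decoded samples back to [0, 1] for display.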
class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()
    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for all VAE models"""
        vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
            "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
            "playground-v2.5": AutoencoderKL.from_pretrained("playgroundai/playground-v2.5-1024px-aesthetic", subfolder="vae").to(self.device),
            # "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
            "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
            "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
        }
        # Define the desired display order of models
        order = [
            "stable-diffusion-v1-4",
            "eq-vae-ema",
            "eq-sdxl-vae",
            "sd-vae-ft-mse",
            "sdxl-vae",
            "playground-v2.5",
            "stable-diffusion-3-medium",
            "FLUX.1",
            "CogView4-6B",
            # "dc-ae-f32c32-sana-1.0",
            "FLUX.1-Kontext",
            "FLUX.2",
            "FLUX.2-TinyAutoEncoder",
        ]
        # Construct the vae_models dictionary in the specified order
        return {name: {"vae": vaes[name], "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32} for name in order}
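
    # A single VAE round trip: preprocess, encode/decode under no_grad, time it,
    # and build a binary difference map against the unnormalized original.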
    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float, vae_name: str):
        """Process image through a single VAE model"""
        dtype = model_config["dtype"]
        vae = model_config["vae"]
        img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        # Time the encoding-decoding process
        start_time = time.time()
        with torch.no_grad():
            if vae_name == "FLUX.2-TinyAutoEncoder":
                # The Tiny AutoEncoder is loaded via trust_remote_code and uses its own encode/decode API
                encoded = vae.encode(img_transformed, return_dict=False)
                decoded = vae.decode(encoded, return_dict=False)
            else:
                encoded = vae.encode(img_transformed).latent_dist.sample()
                decoded = vae.decode(encoded).sample
        processing_time = time.time() - start_time

        decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)
        diff = (original_base - reconstructed).abs()
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score, processing_time
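
    # Run every configured VAE on the same input and collect
    # (diff map, reconstruction, score, seconds) per model.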
    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process image through all configured VAEs"""
        results = {}
        for vae_name, model_config in self.vae_models.items():
            results[vae_name] = self.process_image(img, model_config, tolerance, vae_name)
        return results
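
# A fresh VAETester (and a full model reload) is built per request because the
# resize transforms depend on the selected image size.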
@spaces.GPU  # request a ZeroGPU device for this call; this is what the `spaces` import above is for
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio interface function to test all VAEs"""
    tester = VAETester(img_size=img_size)
    try:
        img_tensor = read_image(image_path)
        results = tester.process_all_models(img_tensor, tolerance)
        diff_images = []
        recon_images = []
        scores = []
        for name in tester.vae_models.keys():
            diff_img, recon_img, score, proc_time = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")
        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return [None], [None], error_msg
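
# Example thumbnails come from the local examples/ directory; the Examples row
# below is only rendered when at least one file is found.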
examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]

custom_css = """
.center-header {
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 0 0 10px 0;
}
.monospace-text {
    font-family: 'Courier New', Courier, monospace;
}
"""
with gr.Blocks(title="VAE Performance Tester", css=custom_css) as demo:
    gr.Markdown("<div class='center-header'><h1>VAE Comparison Tool</h1></div>")
    gr.Markdown("""
    Upload an image or select an example to compare how different VAEs reconstruct it.

    1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
    2. Each VAE encodes the image into a latent space and decodes it back.
    3. Outputs include:
        - **Difference Maps**: Where the reconstruction differs from the original (white = difference > tolerance).
        - **Reconstructed Images**: Outputs from each VAE.
        - **Sum of Differences and Time**: Total pixels exceeding the tolerance (lower is better) and processing time in seconds.

    Adjust the tolerance to change sensitivity.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged.",
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=12, elem_classes="monospace-text")

    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output],
    )
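
# share=False keeps the app from opening a public tunnel (not needed on Spaces);
# ssr_mode=False disables Gradio's server-side rendering.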
if __name__ == "__main__":
    demo.launch(share=False, ssr_mode=False)
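
# Minimal programmatic usage sketch (outside the Gradio UI). Assumes an image at
# "examples/sample.png" -- a hypothetical path used only for illustration:
#
#     tester = VAETester(img_size=512)
#     img = read_image("examples/sample.png")
#     results = tester.process_all_models(img, tolerance=0.1)
#     diff_img, recon_img, score, seconds = results["sdxl-vae"]
#     recon_img.save("sdxl_vae_reconstruction.png")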