# app.py: VAE Performance Tester (Hugging Face Space, "Running on Zero" / ZeroGPU)
import gradio as gr
import torch
from diffusers import AutoencoderKL
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image, ImageReadMode
from typing import Dict
import os
from huggingface_hub import login
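# Assumption: the "Running on Zero" badge means this Space runs on ZeroGPU,
# where the GPU is only attached while a @spaces.GPU-decorated function executes.
import spaces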

# Read the Hugging Face token from the environment (set it as a Space secret);
# it is needed because some checkpoints below (SD3 Medium, FLUX.1-dev) are gated.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)


class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        # Pad left/right by 128 px (edge replication), resize to the 512x512 the
        # VAEs expect, and normalize from [0, 1] to [-1, 1].
        self.input_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Same geometry without the [-1, 1] normalization; kept as the reference image.
        self.base_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Inverse of the input normalization: x -> (x + 1) / 2, mapping [-1, 1] back to [0, 1]
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        # Load all VAE models once at initialization
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
        """Load all available VAE models."""
        # (repo id, subfolder) pairs; an empty subfolder means the repo root is the VAE itself
        vae_configs = {
            "Stable Diffusion 3 Medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
            "Stable Diffusion v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
            "SD VAE FT-MSE": ("stabilityai/sd-vae-ft-mse", ""),
            "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae"),
        }
        vae_dict = {}
        for name, (path, subfolder) in vae_configs.items():
            vae_dict[name] = AutoencoderKL.from_pretrained(path, subfolder=subfolder or None).to(self.device)
        return vae_dict

    def process_image(self,
                      img: torch.Tensor,
                      vae: AutoencoderKL,
                      tolerance: float):
        """Run one image through a single VAE and measure the reconstruction error."""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        with torch.no_grad():
            # Encode to latents and decode straight back. For a pure round trip the
            # scaling_factor multiply/divide cancels out, so decode the raw latents;
            # scaling only matters when latents are handed to a diffusion model.
            encoded = vae.encode(img_transformed).latent_dist.sample()
            decoded = vae.decode(encoded).sample

        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)

        # Binary map of pixels where any channel deviates by more than the tolerance
        diff = (original_base - reconstructed).abs()
        bw_diff = (diff > tolerance).any(dim=0).float()

        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()  # number of differing pixels
        return diff_image, recon_image, diff_score

    def process_all_models(self,
                           img: torch.Tensor,
                           tolerance: float):
        """Process the image through all loaded VAEs."""
        results = {}
        for name, vae in self.vae_models.items():
            diff_img, recon_img, score = self.process_image(img, vae, tolerance)
            results[name] = (diff_img, recon_img, score)
        return results


# Initialize the tester once at startup so all models are loaded before the UI serves requests
tester = VAETester()

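# Assumed ZeroGPU entry point: the @spaces.GPU decorator requests a GPU for the
# duration of this call (and should be a no-op when running outside Spaces).
@spaces.GPU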
def test_all_vaes(image_path: str, tolerance: float):
    """Gradio handler: test all VAEs on one uploaded image."""
    try:
        # Force 3-channel RGB so grayscale or RGBA uploads do not break the VAEs
        img_tensor = read_image(image_path, mode=ImageReadMode.RGB)
        results = tester.process_all_models(img_tensor, tolerance)

        diff_images = []
        recon_images = []
        scores = []
        for name in tester.vae_models.keys():
            diff_img, recon_img, score = results[name]
            diff_images.append(diff_img)
            recon_images.append(recon_img)
            scores.append(f"{name}: {score:.0f} differing pixels")
        # The scores go to a Textbox, which expects a single string
        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        return None, None, f"Error: {e}"


# Gradio interface
with gr.Blocks(title="VAE Performance Tester") as demo:
    gr.Markdown("# VAE Performance Testing Tool")
    gr.Markdown("Upload an image to compare all VAE models simultaneously")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance"
            )
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Difference Scores", lines=4)

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch()
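
# Local usage note (assumption, not part of the Space): with gradio, torch, torchvision,
# diffusers, huggingface_hub and the `spaces` package installed, and HF_TOKEN exported,
# `python app.py` serves the same UI on Gradio's default local address.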