import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL, AutoencoderDC, AutoModel
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
import time
from huggingface_hub import login


# Authenticate with the Hugging Face Hub using a token from the environment
# (required for gated repositories such as black-forest-labs/FLUX.1-dev)
hf_token = os.getenv("access_token")
login(token=hf_token)

class PadToSquare:
    """Custom transform to pad an image to square dimensions"""
    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)  # (left, top, right, bottom)
        return transforms.functional.pad(img, padding, padding_mode="edge")  # replicate border pixels rather than zero-pad

class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
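        # Model input: pad/resize to a square, scale to [0, 1], then normalize to [-1, 1]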
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
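        # Reference image: same geometry, but kept in [0, 1] for pixel-wise comparison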
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
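        # Inverse of the [-1, 1] normalization: maps decoded outputs back to [0, 1]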
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for all VAE models"""
        vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
            "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),            
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
            "playground-v2.5": AutoencoderKL.from_pretrained("playgroundai/playground-v2.5-1024px-aesthetic", subfolder="vae").to(self.device),
			# "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
            "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
            "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
        }
        # Define the desired order of models
        order = [
            "stable-diffusion-v1-4",
            "eq-vae-ema",
            "eq-sdxl-vae",        
            "sd-vae-ft-mse",
            "sdxl-vae",
            "playground-v2.5",
            "stable-diffusion-3-medium",
            "FLUX.1",
            "CogView4-6B",
			# "dc-ae-f32c32-sana-1.0",
            "FLUX.1-Kontext",
            "FLUX.2",
            "FLUX.2-TinyAutoEncoder",
        ]

        # Construct the vae_models dictionary in the specified order; every model
        # runs in float32 except the Tiny autoencoder, which was loaded in bfloat16
        return {
            name: {
                "vae": vaes[name],
                "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32,
            }
            for name in order
        }

    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float, vae_name: str):
        """Process image through a single VAE model"""
        dtype = model_config["dtype"]
        vae = model_config["vae"]
        img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        # Time the encoding-decoding process
        start_time = time.time()
        with torch.no_grad():
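            # The Tiny autoencoder (custom remote code) works on plain latent tensors here,
            # while the AutoencoderKL models sample from a latent distribution and
            # return the reconstruction under `.sample`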
            if vae_name == "FLUX.2-TinyAutoEncoder":
                encoded = vae.encode(img_transformed, return_dict=False)
                decoded = vae.decode(encoded, return_dict=False)
            else:
                encoded = vae.encode(img_transformed).latent_dist.sample()
                decoded = vae.decode(encoded).sample
        processing_time = time.time() - start_time

        decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)
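        # Per-pixel absolute error against the [0, 1] reference; a pixel is flagged
        # as "different" if any channel exceeds the tolerance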
        diff = (original_base - reconstructed).abs()
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score, processing_time

    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process image through all configured VAEs"""
        results = {}
        for vae_name, model_config in self.vae_models.items():
            results[vae_name] = self.process_image(img, model_config, tolerance, vae_name)
        return results

@spaces.GPU(duration=20)
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio interface function to test all VAEs"""
    tester = VAETester(img_size=img_size)
    try:
        img_tensor = read_image(image_path)
        results = tester.process_all_models(img_tensor, tolerance)

        diff_images = []
        recon_images = []
        scores = []

        for name in tester.vae_models.keys():
            diff_img, recon_img, score, proc_time = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")

        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        # Return empty galleries instead of [None] so Gradio has nothing to render on failure
        return [], [], error_msg

examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]
custom_css = """
    .center-header {
        display: flex;
        align-items: center;
        justify-content: center;
        margin: 0 0 10px 0;
    }
    .monospace-text {
        font-family: 'Courier New', Courier, monospace;
    }
"""

with gr.Blocks(title="VAE Performance Tester", css=custom_css) as demo:
    gr.Markdown("<div class='center-header'><h1>VAE Comparison Tool</h1></div>")
    gr.Markdown("""
        Upload an image or select an example to compare how different VAEs reconstruct it.
        1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
        2. Each VAE encodes the image into a latent space and decodes it back.
        3. Outputs include:
           - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
           - **Reconstructed Images**: Outputs from each VAE.
           - **Sum of Differences and Time**: Total pixels exceeding tolerance (lower is better) and processing time in seconds.
        Adjust tolerance to change sensitivity.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged."
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")

        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=12, elem_classes="monospace-text")

    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch(share=False, ssr_mode=False)