rizavelioglu committed
Commit 331d5ce · Parent: 9175dc1

- remove remote-VAE support
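For context on the feature being removed: the remote path encoded images with a locally loaded VAE and then sent the latents to a hosted Hugging Face Inference Endpoint for decoding through diffusers' remote_decode helper. A minimal sketch of that pattern, reusing the keyword arguments and one of the hard-coded endpoint URLs deleted in the diff below (the latent tensor here is a random stand-in, purely illustrative):

import torch
from diffusers.utils.remote_utils import remote_decode

# Stand-in for latents from a local SD-1.x VAE encoder: 4 latent channels,
# 64x64 spatial size for a 512x512 input. In app.py these came from
# vae.encode(...).latent_dist.sample().
latents = torch.randn(1, 4, 64, 64)

# Decode on the hosted endpoint instead of locally; this is the call path being dropped.
decoded = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",  # sd-vae-ft-mse endpoint from the removed code
    tensor=latents,
    do_scaling=False,
    output_type="pt",
    return_type="pt",
    partial_postprocess=False,
)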

Files changed (1):
  app.py  +16 -49
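One structural note that may help while reading the diff: entries in self.vae_models lose their "type"/"endpoint" bookkeeping and instead carry a per-model dtype. A small before/after sketch of the config shapes, using a VAE repo id that appears in the diff (purely illustrative):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# Before this commit, local and remote entries had different shapes.
cfg_local = {"type": "local", "vae": vae}
cfg_remote = {
    "type": "remote",
    "local_vae_key": "sd-vae-ft-mse",
    "endpoint": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
}

# After it, every entry is local and uniform; only "FLUX.2-TinyAutoEncoder" uses torch.bfloat16.
cfg_now = {"vae": vae, "dtype": torch.float32}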
app.py CHANGED
@@ -1,8 +1,7 @@
 import spaces
 import gradio as gr
 import torch
-from diffusers import AutoencoderKL, AutoencoderDC
-from diffusers.utils.remote_utils import remote_decode
+from diffusers import AutoencoderKL, AutoencoderDC, AutoModel
 import torchvision.transforms.v2 as transforms
 from torchvision.io import read_image
 from typing import Dict
@@ -42,18 +41,9 @@ class VAETester:
         self.output_transform = transforms.Normalize(mean=[-1], std=[2])
         self.vae_models = self._load_all_vaes()

-    def _get_endpoint(self, base_name: str) -> str:
-        """Helper method to get the endpoint for a given base model name"""
-        endpoints = {
-            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
-            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
-            "FLUX.1": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
-        }
-        return endpoints[base_name]
-
     def _load_all_vaes(self) -> Dict[str, Dict]:
-        """Load configurations for local and remote VAE models"""
-        local_vaes = {
+        """Load configurations for all VAE models"""
+        vaes = {
             "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
             "eq-vae-ema": AutoencoderKL.from_pretrained("zelaki/eq-vae-ema").to(self.device),
             "eq-sdxl-vae": AutoencoderKL.from_pretrained("KBlueLeaf/EQ-SDXL-VAE").to(self.device),
@@ -66,6 +56,7 @@ class VAETester:
             # "dc-ae-f32c32-sana-1.0": AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers").to(self.device),
             "FLUX.1-Kontext": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae").to(self.device),
             "FLUX.2": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae").to(self.device),
+            "FLUX.2-TinyAutoEncoder": AutoModel.from_pretrained("fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device),
         }
         # Define the desired order of models
         order = [
@@ -73,66 +64,42 @@ class VAETester:
             "eq-vae-ema",
             "eq-sdxl-vae",
             "sd-vae-ft-mse",
-            #"sd-vae-ft-mse (remote)",
             "sdxl-vae",
-            #"sdxl-vae (remote)",
             "playground-v2.5",
             "stable-diffusion-3-medium",
             "FLUX.1",
-            #"FLUX.1 (remote)",
             "CogView4-6B",
             # "dc-ae-f32c32-sana-1.0",
             "FLUX.1-Kontext",
             "FLUX.2",
+            "FLUX.2-TinyAutoEncoder",
         ]

         # Construct the vae_models dictionary in the specified order
-        vae_models = {}
-        for name in order:
-            if "(remote)" not in name:
-                # Local model
-                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
-            else:
-                # Remote model
-                base_name = name.replace(" (remote)", "")
-                vae_models[name] = {
-                    "type": "remote",
-                    "local_vae_key": base_name,
-                    "endpoint": self._get_endpoint(base_name),
-                }
-
-        return vae_models
+        return {name: {"vae": vaes[name], "dtype": torch.bfloat16 if name == "FLUX.2-TinyAutoEncoder" else torch.float32} for name in order}

     def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
-        """Process image through a single VAE (local or remote)"""
-        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
+        """Process image through a single VAE model"""
+        dtype = model_config["dtype"]
+        img_transformed = self.input_transform(img).to(dtype).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()

         # Start timer
         start_time = time.time()

-        if model_config["type"] == "local":
-            vae = model_config["vae"]
-            with torch.no_grad():
+        vae = model_config["vae"]
+        with torch.no_grad():
+            if isinstance(vae, AutoModel):
+                encoded = vae.encode(img_transformed, return_dict=False)
+                decoded = vae.decode(encoded, return_dict=False)
+            else:
                 encoded = vae.encode(img_transformed).latent_dist.sample()
                 decoded = vae.decode(encoded).sample
-        elif model_config["type"] == "remote":
-            local_vae = self.vae_models[model_config["local_vae_key"]]["vae"]
-            with torch.no_grad():
-                encoded = local_vae.encode(img_transformed).latent_dist.sample()
-                decoded = remote_decode(
-                    endpoint=model_config["endpoint"],
-                    tensor=encoded,
-                    do_scaling=False,
-                    output_type="pt",
-                    return_type="pt",
-                    partial_postprocess=False,
-                )

         # End timer
         processing_time = time.time() - start_time

-        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
+        decoded_transformed = self.output_transform(decoded.squeeze(0).to(torch.float32)).cpu()
         reconstructed = decoded_transformed.clip(0, 1)
         diff = (original_base - reconstructed).abs()
         bw_diff = (diff > tolerance).any(dim=0).float()
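For readers who want the reworked decode dispatch outside the diff context, here is a minimal standalone sketch. Assumptions: the input is a dummy tensor in [-1, 1] (the range implied by output_transform Normalize(mean=[-1], std=[2]), which maps [-1, 1] back to [0, 1]); the Tiny-AutoEncoder encode/decode signatures are copied from the new branch in the diff, and the actual remote code in fal/FLUX.2-Tiny-AutoEncoder may differ; model downloads and, ideally, a GPU are available.

import torch
from diffusers import AutoencoderKL, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 3, 512, 512, device=device) * 2 - 1  # dummy image batch in [-1, 1]

# Standard path (all AutoencoderKL entries): sample the latent distribution, then decode; float32 throughout.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample()
    recon = vae.decode(latents).sample

# Tiny-AutoEncoder path: loaded through AutoModel with remote code, run in bfloat16,
# and called with return_dict=False, mirroring the new branch in process_image.
tiny = AutoModel.from_pretrained(
    "fal/FLUX.2-Tiny-AutoEncoder", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)
with torch.no_grad():
    tiny_latents = tiny.encode(x.to(torch.bfloat16), return_dict=False)
    tiny_recon = tiny.decode(tiny_latents, return_dict=False)
tiny_recon = tiny_recon.to(torch.float32)  # cast back before any float32 post-processing

In the app itself the dtype lives in the per-model config built by _load_all_vaes, so process_image casts the input to the model's dtype on the way in and casts the decoded tensor back to float32 before output_transform; since every entry except the Tiny AutoEncoder is pinned to float32, those casts are no-ops for the rest.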