Spaces: Running on Zero

update aligned, fix z-score

app.py CHANGED
@@ -1,7 +1,10 @@
 # Author: Huzheng Yang
 # %%
 import copy
+from io import BytesIO
 import os
+
+from matplotlib import pyplot as plt
 USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
 DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
 
@@ -241,7 +244,7 @@ def ncut_run(
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        resolution = (
+        resolution = (224, 224)
     else:
         resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
@@ -357,11 +360,18 @@ def ncut_run(
 
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
-        galleries = []
-        for i_node in range(rgb.shape[1]):
-            _rgb = rgb[:, i_node]
-            galleries.append(to_pil_images(_rgb, target_size=56))
-        return *galleries, logging_str
+        # galleries = []
+        # for i_node in range(rgb.shape[1]):
+        #     _rgb = rgb[:, i_node]
+        #     galleries.append(to_pil_images(_rgb, target_size=56))
+        # return *galleries, logging_str
+        pil_images = []
+        for i_image in range(rgb.shape[0]):
+            _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
+            pil_images.append(_im)
+        return pil_images, logging_str
+
+
 
     if is_lisa == True:
         # dirty patch for the LISA model
@@ -457,9 +467,78 @@ def transform_image(image, resolution=(1024, 1024)):
     image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
     image = image / 255
     # Normalize
-
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
     return image
 
+def plot_one_image_36_grid(original_image, tsne_rgb_images):
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    original_image = original_image * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
+    original_image = torch.clamp(original_image, 0, 1)
+
+    fig = plt.figure(figsize=(20, 4))
+    grid = plt.GridSpec(3, 14, hspace=0.1, wspace=0.1)
+
+    ax1 = fig.add_subplot(grid[0:2, 0:2])
+    img = original_image.cpu().float().numpy().transpose(1, 2, 0)
+
+    def convert_and_pad_image(np_array, pad_size=20):
+        """
+        Converts a NumPy array of shape (height, width, 3) to a PNG image
+        and pads the right and bottom sides with a transparent background.
+
+        Args:
+            np_array (numpy.ndarray): Input NumPy array of shape (height, width, 3)
+            pad_size (int, optional): Number of pixels to pad on the right and bottom sides. Default is 20.
+
+        Returns:
+            PIL.Image: Padded PNG image with transparent background
+        """
+        # Convert NumPy array to PIL Image
+        img = Image.fromarray(np_array)
+
+        # Get the original size
+        width, height = img.size
+
+        # Create a new image with padding and transparent background
+        new_width = width + pad_size
+        new_height = height + pad_size
+        padded_img = Image.new('RGBA', (new_width, new_height), color=(255, 255, 255, 0))
+
+        # Paste the original image onto the padded image
+        padded_img.paste(img, (0, 0))
+
+        return padded_img
+
+    img = convert_and_pad_image((img*255).astype(np.uint8))
+    ax1.imshow(img)
+    ax1.axis('off')
+
+    model_names = ['CLIP', 'DINO', 'MAE']
+
+    for i_model, model_name in enumerate(model_names):
+        for i_layer in range(12):
+            ax = fig.add_subplot(grid[i_model, i_layer+2])
+            ax.imshow(tsne_rgb_images[i_layer+12*i_model].cpu().float().numpy())
+            ax.axis('off')
+            if i_model == 0:
+                ax.set_title(f'Layer{i_layer}', fontsize=16)
+            if i_layer == 0:
+                ax.text(-0.1, 0.5, model_name, va="center", ha="center", fontsize=16, transform=ax.transAxes, rotation=90,)
+    plt.tight_layout()
+    buf = BytesIO()
+    plt.savefig(buf, bbox_inches='tight', pad_inches=0, dpi=100)
+
+    buf.seek(0)  # Move to the start of the BytesIO buffer
+    img = Image.open(buf)
+    img = img.convert("RGB")
+    img = copy.deepcopy(img)
+    buf.close()
+    plt.close()
+    return img
+
 def load_alignedthreemodel():
 
     os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
@@ -687,10 +766,10 @@ def make_input_video_section():
     clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
     return input_gallery, submit_button, clear_images_button, max_frames_number
 
-def make_dataset_images_section(advanced=False):
+def make_dataset_images_section(advanced=False, is_random=False):
 
     gr.Markdown('### Load Datasets')
-    load_images_button = gr.Button("Load", elem_id="load-images-button", variant='
+    load_images_button = gr.Button("🟢 Load Images", elem_id="load-images-button", variant='primary')
     advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio")
     with gr.Column() as basic_block:
         example_gallery = gr.Gallery(value=example_items, label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
@@ -700,10 +779,17 @@ def make_dataset_images_section(advanced=False):
         with gr.Row():
             dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset", min_width=300)
             num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
-
-
-
-
+            if not is_random:
+                filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
+                filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
+                is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
+                random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
+            if is_random:
+                filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
+                filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
+                is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
+                random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
+
 
     if advanced:
         advanced_block.visible = True
@@ -1168,12 +1254,18 @@ with demo:
         with gr.Column(scale=5, min_width=200):
             input_gallery, submit_button, clear_images_button = make_input_images_section()
 
-            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+            dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
            num_images_slider.value = 100
 
+
        with gr.Column(scale=5, min_width=200):
+            output_gallery = make_output_images_section()
+            gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate')
+            gr.Markdown('---')
            gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
            gr.Markdown('Layer type: attention output (attn), without sum of residual')
+            gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
+            gr.Markdown('---')
        [
            model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
            affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
@@ -1185,20 +1277,23 @@ with demo:
        model_dropdown.visible = False
        layer_slider.visible = False
        node_type_dropdown.visible = False
+        num_sample_ncut_slider.value = 10000
+        num_sample_tsne_slider.value = 1000
        # logging text box
        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
-        galleries = []
-        for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
-            with gr.Row():
-                for i_layer in range(1, 13):
-                    with gr.Column(scale=5, min_width=200):
-                        gr.Markdown(f'### {model_name} Layer {i_layer}')
-                        output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
-                        galleries.append(output_gallery)
+        # galleries = []
+        # for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+        #     with gr.Row():
+        #         for i_layer in range(1, 13):
+        #             with gr.Column(scale=5, min_width=200):
+        #                 gr.Markdown(f'### {model_name} Layer {i_layer}')
+        #                 output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+        #                 galleries.append(output_gallery)
 
 
-        clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        # clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        clear_images_button.click(lambda x: ([], []), outputs=[input_gallery, output_gallery])
 
        false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
@@ -1213,7 +1308,8 @@ with demo:
                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
            ],
-            outputs=galleries + [logging_text],
+            # outputs=galleries + [logging_text],
+            outputs=[output_gallery, logging_text],
        )
 
    with gr.Tab('Compare Models'):
@@ -1320,4 +1416,4 @@ if DOWNLOAD_ALL_MODELS_DATASETS:
    demo.launch(share=True)
 
 
-# %%
+# %%
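
The "fix z-score" part of this commit standardizes inputs in transform_image with the ImageNet mean/std and undoes that standardization in plot_one_image_36_grid before plotting. For reference, the snippet below is a minimal standalone sketch of that normalize/denormalize pair; the helper names normalize and denormalize are illustrative and not functions in app.py, and only torch plus the mean/std values from the diff are assumed.

import torch

# ImageNet statistics used in the diff (transform_image / plot_one_image_36_grid)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(image):
    # image: float tensor of shape (3, H, W) with values in [0, 1], as produced by transform_image
    return (image - IMAGENET_MEAN) / IMAGENET_STD

def denormalize(image):
    # inverse of normalize(); clamp to [0, 1] before plotting, as plot_one_image_36_grid does
    return torch.clamp(image * IMAGENET_STD + IMAGENET_MEAN, 0, 1)

x = torch.rand(3, 224, 224)  # stand-in for a resized, rescaled input image
assert torch.allclose(denormalize(normalize(x)), x, atol=1e-6)

Round-tripping through these two steps recovers the original pixel values, which is the property the un-normalization in plot_one_image_36_grid relies on when it renders the original image next to the per-layer embeddings.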