Spaces:
Running
on
Zero
Running
on
Zero
fix msg, faster eig norm
Browse files- app.py +32 -9
- fps_cluster.py +2 -1
app.py
CHANGED
|
@@ -2232,6 +2232,12 @@ with demo:
|
|
| 2232 |
def __run_fn(*args, **kwargs):
|
| 2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
| 2234 |
rgb_gallery = to_pil_images(rgb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2235 |
return eigvecs, rgb, rgb_gallery, logging_str
|
| 2236 |
|
| 2237 |
submit_button.click(
|
|
@@ -2408,8 +2414,8 @@ with demo:
|
|
| 2408 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
| 2409 |
with gr.Accordion("Outputs", open=True):
|
| 2410 |
gr.Markdown("""
|
| 2411 |
-
1. spectral-tSNE tree: ◆
|
| 2412 |
-
2. Cluster Heatmap: max cosine similarity to
|
| 2413 |
""")
|
| 2414 |
with gr.Column(scale=5, min_width=200):
|
| 2415 |
prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
|
|
@@ -2427,6 +2433,7 @@ with demo:
|
|
| 2427 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
| 2428 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
| 2429 |
outputs=[tsne_prompt_image])
|
|
|
|
| 2430 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
| 2431 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 2432 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
|
@@ -2500,8 +2507,8 @@ with demo:
|
|
| 2500 |
x = int(x * w)
|
| 2501 |
y = int(y * h)
|
| 2502 |
eigvec = _eigvec[y, x]
|
| 2503 |
-
|
| 2504 |
-
closest_idx = np.
|
| 2505 |
return closest_idx
|
| 2506 |
|
| 2507 |
def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
|
@@ -2529,13 +2536,29 @@ with demo:
|
|
| 2529 |
output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
|
| 2530 |
|
| 2531 |
# draw heatmap for the connected components
|
|
|
|
| 2532 |
connected_eigvecs = fps_eigvecs[connected_idxs]
|
| 2533 |
-
left =
|
| 2534 |
-
right =
|
| 2535 |
-
left = F.normalize(left, p=2, dim=-1)
|
| 2536 |
-
right = F.normalize(right, p=2, dim=-1)
|
|
|
|
| 2537 |
similarity = left @ right.T
|
| 2538 |
-
similarity = similarity.max(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2539 |
hot_map = matplotlib.colormaps['hot']
|
| 2540 |
heatmap = hot_map(similarity)[..., :3] # B H W 3
|
| 2541 |
heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
|
|
|
| 2232 |
def __run_fn(*args, **kwargs):
|
| 2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
| 2234 |
rgb_gallery = to_pil_images(rgb)
|
| 2235 |
+
# normalize the eigvecs
|
| 2236 |
+
eigvecs = torch.tensor(eigvecs)
|
| 2237 |
+
if torch.cuda.is_available():
|
| 2238 |
+
eigvecs = eigvecs.cuda()
|
| 2239 |
+
eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
| 2240 |
+
eigvecs = eigvecs.cpu().numpy()
|
| 2241 |
return eigvecs, rgb, rgb_gallery, logging_str
|
| 2242 |
|
| 2243 |
submit_button.click(
|
|
|
|
| 2414 |
gr.Markdown("Known Issue: Resize the browser window will break the clicking, please refresh the page.")
|
| 2415 |
with gr.Accordion("Outputs", open=True):
|
| 2416 |
gr.Markdown("""
|
| 2417 |
+
1. spectral-tSNE tree: ◆ marker is the N points, connected components to the clicked dot.
|
| 2418 |
+
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
| 2419 |
""")
|
| 2420 |
with gr.Column(scale=5, min_width=200):
|
| 2421 |
prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
|
|
|
|
| 2433 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
| 2434 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
| 2435 |
outputs=[tsne_prompt_image])
|
| 2436 |
+
prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2437 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
| 2438 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 2439 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
|
|
|
| 2507 |
x = int(x * w)
|
| 2508 |
y = int(y * h)
|
| 2509 |
eigvec = _eigvec[y, x]
|
| 2510 |
+
sim = fps_eigvecs @ eigvec
|
| 2511 |
+
closest_idx = np.argmax(sim)
|
| 2512 |
return closest_idx
|
| 2513 |
|
| 2514 |
def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
|
|
|
| 2536 |
output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
|
| 2537 |
|
| 2538 |
# draw heatmap for the connected components
|
| 2539 |
+
## cosine distance
|
| 2540 |
connected_eigvecs = fps_eigvecs[connected_idxs]
|
| 2541 |
+
left = eigvecs.astype(np.float32) # B H W C
|
| 2542 |
+
right = connected_eigvecs.astype(np.float32) # N C
|
| 2543 |
+
# left = F.normalize(left, p=2, dim=-1)
|
| 2544 |
+
# right = F.normalize(right, p=2, dim=-1)
|
| 2545 |
+
# eigvec is already normalized when saved to gr.State
|
| 2546 |
similarity = left @ right.T
|
| 2547 |
+
similarity = similarity.max(axis=-1) # B H W N
|
| 2548 |
+
## euclidean distance
|
| 2549 |
+
# b, h, w = tsne3d_rgb.shape[:3]
|
| 2550 |
+
# tsne3d_rgb = tsne3d_rgb.reshape(b*h*w, 3)
|
| 2551 |
+
# connected_rgb = tsne3d_rgb[fps_indices][connected_idxs]
|
| 2552 |
+
# left = torch.tensor(tsne3d_rgb).float() # (B H W) 3
|
| 2553 |
+
# right = torch.tensor(connected_rgb).float() # N 3
|
| 2554 |
+
# # dist B H W N
|
| 2555 |
+
# dist = left[:, None] - right[None]
|
| 2556 |
+
# dist = torch.sqrt((dist ** 2).sum(dim=-1))
|
| 2557 |
+
# dist = dist.min(dim=-1).values # B H W
|
| 2558 |
+
# dist = dist.reshape(b, h, w)
|
| 2559 |
+
# gr.Info(f"dist: min={dist.min().item()}, max={dist.max().item()}, mean={dist.mean().item()}", 3)
|
| 2560 |
+
# similarity = 1 - dist
|
| 2561 |
+
|
| 2562 |
hot_map = matplotlib.colormaps['hot']
|
| 2563 |
heatmap = hot_map(similarity)[..., :3] # B H W 3
|
| 2564 |
heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
|
fps_cluster.py
CHANGED
|
@@ -5,7 +5,8 @@ import torch
|
|
| 5 |
|
| 6 |
def build_tree(all_dots):
|
| 7 |
num_sample = all_dots.shape[0]
|
| 8 |
-
center = all_dots.mean(axis=0)
|
|
|
|
| 9 |
distances_to_center = np.linalg.norm(all_dots - center, axis=1)
|
| 10 |
start_idx = np.argmin(distances_to_center)
|
| 11 |
indices = [start_idx]
|
|
|
|
| 5 |
|
| 6 |
def build_tree(all_dots):
|
| 7 |
num_sample = all_dots.shape[0]
|
| 8 |
+
# center = all_dots.mean(axis=0)
|
| 9 |
+
center = np.median(all_dots, axis=0)
|
| 10 |
distances_to_center = np.linalg.norm(all_dots - center, axis=1)
|
| 11 |
start_idx = np.argmin(distances_to_center)
|
| 12 |
indices = [start_idx]
|