Samuel Stevens
committed
Commit · 852b07a
1 Parent(s): c4ee5c3
wip: adding mod preds
app.py CHANGED

@@ -46,7 +46,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 CWD = pathlib.Path(".")
 """Current working directory."""

-N_SAE_LATENTS =
+N_SAE_LATENTS = 2
 """Number of SAE latents to show."""

 N_LATENT_EXAMPLES = 4
@@ -289,7 +289,7 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:


 @torch.inference_mode
-def
+def get_orig_preds(i: int) -> dict[str, object]:
     img = data.get_img(i)
     split_vit, vit_transform = modeling.load_vit(DEVICE)

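This hunk only fills in the header of get_orig_preds; the body below it is unchanged context. A minimal sketch of how such a function could finish and return JSON-serializable predictions, assuming the split ViT exposes a forward_end counterpart to the forward_start seen later in this diff, and returning raw class indices since no label list is visible here (both are assumptions, not confirmed by this commit):

@torch.inference_mode
def get_orig_preds(i: int) -> dict[str, object]:
    img = data.get_img(i)
    split_vit, vit_transform = modeling.load_vit(DEVICE)
    x = vit_transform(img)[None, ...].to(DEVICE)
    # forward_end is an assumed counterpart to forward_start; the commit
    # only shows forward_start being called.
    x_BPD = split_vit.forward_start(x)
    logits = split_vit.forward_end(x_BPD)
    probs = torch.nn.functional.softmax(logits[0], dim=-1)
    top = torch.topk(probs, k=5)
    return {"classes": top.indices.tolist(), "probs": top.values.tolist()}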
@@ -331,16 +331,10 @@ def map_range(
     return c + (x - a) * (d - c) / (b - a)


+@beartype.beartype
 @torch.inference_mode
-def
-
-    latent1: int,
-    latent2: int,
-    latent3: int,
-    value1: float,
-    value2: float,
-    value3: float,
-) -> list[Image.Image | list[int]]:
+def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
+    breakpoint()
     sample = vit_dataset[i]
     x = sample["image"][None, ...].to(device)
     x_BPD = rest_of_vit.forward_start(x)
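In this commit get_mod_preds is still a stub: it drops straight into breakpoint(). One plausible shape for the finished body, sketched under the assumption that an SAE object exposes encode/decode and that rest_of_vit.forward_end completes the pass begun by forward_start (every name beyond forward_start and the visible locals is a guess):

    sample = vit_dataset[i]
    x = sample["image"][None, ...].to(device)
    x_BPD = rest_of_vit.forward_start(x)    # activations at the split point
    f_BPS = sae.encode(x_BPD)               # hypothetical: SAE latent activations
    for latent, value in latents.items():
        f_BPS[..., latent] = value          # pin each requested latent to its value
    x_hat_BPD = sae.decode(f_BPS)           # hypothetical: back to ViT activation space
    logits = rest_of_vit.forward_end(x_hat_BPD)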
@@ -429,38 +423,38 @@ with gr.Blocks() as demo:
         api_name="get-sae-latents",
     )

-
-    # get-preds #
-
+    ##################
+    # get-orig-preds #
+    ##################

     # Outputs
-
+    get_orig_preds_out = gr.JSON(label="get_orig_preds_out", value=[])

     get_pred_labels_btn = gr.Button(value="Get Predictions")
     get_pred_labels_btn.click(
-
+        get_orig_preds,
+        inputs=[img_number],
+        outputs=[get_orig_preds_out],
+        api_name="get-orig-preds",
     )

-
-    #
-
-
-    #
-
-
-
-
-
-
-
-
-
-
-
-    # outputs=[semseg_img, semseg_colors],
-    # api_name="get-modified-labels",
-    # )
+    #################
+    # get-mod-preds #
+    #################
+
+    # Inputs
+    latents_json = gr.JSON(label="Modified Latents", value={})
+
+    # Outputs
+    get_mod_preds_out = gr.JSON(label="get_mod_preds_out", value=[])
+
+    get_pred_labels_btn = gr.Button(value="Get Predictions")
+    get_pred_labels_btn.click(
+        get_mod_preds,
+        inputs=[img_number, latents_json],
+        outputs=[get_mod_preds_out],
+        api_name="get-mod-preds",
+    )

 if __name__ == "__main__":
     demo.launch()
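Because both click handlers register explicit api_names, the new endpoints can be exercised programmatically as well as through the UI. A usage sketch with gradio_client (the URL is a placeholder for wherever this app is served):

from gradio_client import Client

client = Client("http://localhost:7860")  # placeholder URL

# Original predictions for image 3.
orig = client.predict(3, api_name="/get-orig-preds")

# Predictions with SAE latent 1024 pinned to 4.0; JSON object keys are strings,
# so the dict[int, float] annotation would need parsing server-side.
mod = client.predict(3, {"1024": 4.0}, api_name="/get-mod-preds")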