Samuel Stevens committed · Commit 699b9c3 · Parent(s): 0ab58fa

bug: SAE examples are not highlighted

Browse files:
- app.py (+110, -157)
- modeling.py (+53, -0)
app.py
CHANGED
@@ -2,7 +2,7 @@ import functools
 import io
 import json
 import logging
-import
+import math
 import pathlib
 import typing

@@ -10,17 +10,19 @@ import beartype
 import einops
 import einops.layers.torch
 import gradio as gr
+import numpy as np
 import saev.activations
 import saev.config
 import saev.nn
 import saev.visuals
 import torch
-from jaxtyping import Float, Int, UInt8, jaxtyped
-from PIL import Image
+from jaxtyping import Bool, Float, Int, UInt8, jaxtyped
+from PIL import Image, ImageDraw
 from torch import Tensor

 import constants
 import data
+import modeling

 logger = logging.getLogger("app.py")

@@ -29,33 +31,26 @@ logger = logging.getLogger("app.py")
 ####################


-"""Whether we are debugging."""
-
-max_frequency = 1e-2
+MAX_FREQ = 1e-2
 """Maximum frequency. Any feature that fires more than this is ignored."""

-n_sae_latents = 3
-"""Number of SAE latents to show."""
-
-n_sae_examples = 4
-"""Number of SAE examples per latent to show."""
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-"""Hardware accelerator, if any."""
-
 RESIZE_SIZE = 512
 """Resize shorter size to this size in pixels."""

 CROP_SIZE = (448, 448)
 """Crop size in pixels."""

-DEVICE =
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 """Hardware accelerator, if any."""

 CWD = pathlib.Path(".")
 """Current working directory."""

+N_SAE_LATENTS = 3
+"""Number of SAE latents to show."""
+
+N_LATENT_EXAMPLES = 4
+"""Number of examples per SAE latent to show."""

 ##########
 # Models #
@@ -63,27 +58,7 @@ CWD = pathlib.Path(".")


 @functools.cache
-def
-    vit = (
-        saev.activations.WrappedVisionTransformer(
-            saev.config.Activations(
-                model_family="dinov2",
-                model_ckpt="dinov2_vitb14_reg",
-                layers=[-2],
-                n_patches_per_img=256,
-            )
-        )
-        .to(DEVICE)
-        .eval()
-    )
-    vit_transform = saev.activations.make_img_transform("dinov2", "dinov2_vitb14_reg")
-    logger.info("Loaded ViT.")
-
-    return vit, vit_transform
-
-
-@functools.cache
-def load_sae() -> saev.nn.SparseAutoencoder:
+def load_sae(device: str) -> saev.nn.SparseAutoencoder:
     """
     Loads a sparse autoencoder from disk.
     """
@@ -102,37 +77,12 @@ def load_clf() -> torch.nn.Module:
         buffer = io.BytesIO(fd.read())

     model = torch.nn.Linear(**kwargs)
-    state_dict = torch.load(buffer, weights_only=True, map_location=
+    state_dict = torch.load(buffer, weights_only=True, map_location=DEVICE)
     model.load_state_dict(state_dict)
-    model = model.to(
+    model = model.to(DEVICE).eval()
     return model


-class RestOfDinoV2(torch.nn.Module):
-    def __init__(self, *, n_end_layers: int):
-        super().__init__()
-        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
-        self.n_end_layers = n_end_layers
-
-    def forward_start(self, x: Float[Tensor, "batch channels width height"]):
-        x_BPD = self.vit.prepare_tokens_with_masks(x)
-        for blk in self.vit.blocks[: -self.n_end_layers]:
-            x_BPD = blk(x_BPD)
-
-        return x_BPD
-
-    def forward_end(self, x_BPD: Float[Tensor, "batch n_patches dim"]):
-        for blk in self.vit.blocks[-self.n_end_layers :]:
-            x_BPD = blk(x_BPD)
-
-        x_BPD = self.vit.norm(x_BPD)
-        return x_BPD[:, self.vit.num_register_tokens + 1 :]
-
-
-rest_of_vit = RestOfDinoV2(n_end_layers=1)
-rest_of_vit = rest_of_vit.to(device)
-
-
 ####################
 # Global Variables #
 ####################
@@ -143,13 +93,23 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
     return torch.load(path, weights_only=True, map_location="cpu")


-# mask = mask & (sparsity < max_frequency)
+@functools.cache
+def load_tensors() -> tuple[
+    Int[Tensor, "d_sae k"],
+    UInt8[Tensor, "d_sae k n_patches"],
+    Bool[Tensor, " d_sae"],
+]:
+    """
+    Loads the tensors for the SAE for ADE20K.
+    """
+    top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
+    top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
+    sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
+
+    mask = torch.ones(sparsity.shape, dtype=bool)
+    mask = mask & (sparsity < MAX_FREQ)
+
+    return top_img_i, top_values, mask


 ############
@@ -157,37 +117,42 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
 ############


-#
-#     root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/"
-# ),
-# img_transform=v2.Compose([
-#     v2.Resize(size=(256, 256)),
-#     v2.CenterCrop(size=(224, 224)),
-#     v2.ToImage(),
-#     v2.ToDtype(torch.float32, scale=True),
-#     v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
-# ]),
-# )
+@jaxtyped(typechecker=beartype.beartype)
+def add_highlights(
+    img: Image.Image,
+    patches: Float[np.ndarray, " n_patches"],
+    *,
+    upper: int | None = None,
+    opacity: float = 0.9,
+) -> Image.Image:
+    breakpoint()
+    if not len(patches):
+        return img
+
+    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
+    iw_px, ih_px = img.size
+    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
+    assert iw_np * ih_np == len(patches)
+
+    # Create a transparent overlay
+    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
+    draw = ImageDraw.Draw(overlay)
+
+    # Using semi-transparent red (255, 0, 0, alpha)
+    for p, val in enumerate(patches):
+        assert upper is not None
+        val /= upper + 1e-9
+        x_np, y_np = p % iw_np, p // ih_np
+        draw.rectangle(
+            [
+                (x_np * pw_px, y_np * ph_px),
+                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
+            ],
+            fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
+        )
+
+    # Composite the original image and the overlay
+    return Image.alpha_composite(img.convert("RGBA"), overlay)


 #######################
@@ -202,12 +167,14 @@ class Example(typing.TypedDict):
     Used to store examples of SAE latent activations for visualization.
     """

+    index: int
+    """Dataset index."""
     orig_url: str
     """The URL or path to access the original example image."""
     highlighted_url: str
     """The URL or path to access the SAE-highlighted image."""
-
-    """
+    seg_url: str
+    """Base64-encoded version of the colored segmentation map."""


 @beartype.beartype
@@ -249,64 +216,73 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]:
     if not patches:
         return []

-    sae = load_sae()
+    split_vit, vit_transform = modeling.load_vit(DEVICE)
+    sae = load_sae(DEVICE)

     img = data.get_image(image_i)

-        - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
+    x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
+
+    x_BPD = split_vit.forward_start(x_BCWH)
+    x_BPD = (
+        x_BPD.clamp(-1e-5, 1e5) - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
     ) / constants.DINOV2_IMAGENET1K_SCALAR

-    #
-    top_img_i, top_values = load_tensors(model_cfg)
-    logger.info("Loaded top SAE activations for '%s'.", model_name)
-
-        for i in patches
-    ]).to(device)
-
-    f_x_S = f_x_MS.sum(axis=0)
+    # Need to pick out the right patches
+    # + 1 + 4 for 1 [CLS] token and 4 register tokens
+    x_PD = x_BPD[0, [p + 1 + 4 for p in patches]]
+    _, f_x_PS, _ = sae(x_PD)
+
+    f_x_S = einops.reduce(f_x_PS, "patches n_latents -> n_latents", "sum")
+    logger.info("Got SAE activations.")
+
+    top_img_i, top_values, mask = load_tensors()

     latents = torch.argsort(f_x_S, descending=True).cpu()
-    latents = latents[mask[latents]][:
+    latents = latents[mask[latents]][:N_SAE_LATENTS].tolist()

+    sae_activations = []
     for latent in latents:
+        pairs, seen_i_im = [], set()
         for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
             if i_im in seen_i_im:
                 continue

-            elems.append(
-                saev.visuals.GridElement(example["image"], example["label"], values_p)
-            )
+            pairs.append((i_im, values_p))
             seen_i_im.add(i_im)
+            if len(pairs) >= N_LATENT_EXAMPLES:
+                break

         # How to scale values.
         upper = None
         if top_values[latent].numel() > 0:
             upper = top_values[latent].max().item()

+        examples = []
+        for i_im, values_p in pairs:
+            seg_sized = data.to_sized(data.get_seg(i_im))
+            img_sized = data.to_sized(data.get_image(i_im))
+
+            seg_u8_sized = data.to_u8(seg_sized)
+            seg_img_sized = data.u8_to_img(seg_u8_sized)
+
+            highlighted_sized = add_highlights(
+                img_sized, values_p.float().numpy(), upper=upper
+            )
+
+            examples.append({
+                "index": i_im,
+                "orig_url": data.img_to_base64(img_sized),
+                "highlighted_url": data.img_to_base64(highlighted_sized),
+                "seg_url": data.img_to_base64(seg_img_sized),
+            })
+
+        sae_activations.append({
+            "latent": latent,
+            "examples": examples,
+        })

-    return
+    return sae_activations


 @torch.inference_mode
@@ -416,29 +392,6 @@ def upsample(
     )


-@beartype.beartype
-def make_img(
-    elem: saev.visuals.GridElement, *, upper: float | None = None
-) -> Image.Image:
-    # Resize to 256x256 and crop to 224x224
-    resize_size_px = (512, 512)
-    resize_w_px, resize_h_px = resize_size_px
-    crop_size_px = (448, 448)
-    crop_w_px, crop_h_px = crop_size_px
-    crop_coords_px = (
-        (resize_w_px - crop_w_px) // 2,
-        (resize_h_px - crop_h_px) // 2,
-        (resize_w_px + crop_w_px) // 2,
-        (resize_h_px + crop_h_px) // 2,
-    )
-
-    img = elem.img.resize(resize_size_px).crop(crop_coords_px)
-    img = saev.imaging.add_highlights(
-        img, elem.patches.numpy(), upper=upper, opacity=0.5
-    )
-    return img
-
-
 with gr.Blocks() as demo:
     image_number = gr.Number(label="Validation Example")

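The new add_highlights helper is the core of this fix: each patch value is normalized by upper and drawn as a red square whose alpha scales with the activation, then composited over the image. Below is a minimal sketch, not part of the commit, for exercising the helper in isolation; it assumes app.py can be imported without launching the demo and that the breakpoint() call is removed or stepped over, and the dummy image, values, and output filename are made up for illustration.

    import numpy as np
    from PIL import Image

    from app import add_highlights  # assumes app.py is importable from the working directory

    # 256 patches corresponds to the 16x16 grid the app assumes for 224x224 DINOv2 inputs.
    dummy_img = Image.new("RGB", (448, 448), (128, 128, 128))
    dummy_values = np.zeros(256, dtype=np.float32)
    dummy_values[:16] = 1.0  # light up the first row of patches

    out = add_highlights(dummy_img, dummy_values, upper=1, opacity=0.9)
    out.convert("RGB").save("highlight_check.png")  # hypothetical output path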
modeling.py
ADDED
@@ -0,0 +1,53 @@
+import functools
+import logging
+import typing
+
+import beartype
+import torch
+from jaxtyping import Float, jaxtyped
+from torch import Tensor
+from torchvision.transforms import v2
+
+logger = logging.getLogger("modeling.py")
+
+
+@jaxtyped(typechecker=beartype.beartype)
+class SplitDinov2(torch.nn.Module):
+    def __init__(self, *, split_at: int):
+        super().__init__()
+
+        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
+        self.split_at = split_at
+
+    def forward_start(
+        self, x: Float[Tensor, "batch channels width height"]
+    ) -> Float[Tensor, "batch patches dim"]:
+        x_BPD = self.vit.prepare_tokens_with_masks(x)
+        for blk in self.vit.blocks[: self.split_at]:
+            x_BPD = blk(x_BPD)
+
+        return x_BPD
+
+    def forward_end(
+        self, x_BPD: Float[Tensor, "batch n_patches dim"]
+    ) -> Float[Tensor, "batch patches dim"]:
+        for blk in self.vit.blocks[-self.split_at :]:
+            x_BPD = blk(x_BPD)
+
+        x_BPD = self.vit.norm(x_BPD)
+        return x_BPD[:, self.vit.num_register_tokens + 1 :]
+
+
+@functools.cache
+def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
+    vit = SplitDinov2(split_at=11).to(device)
+    vit_transform = v2.Compose([
+        v2.Resize(size=(256, 256)),
+        v2.CenterCrop(size=(224, 224)),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
+    ])
+    logger.info("Loaded ViT.")
+
+    return vit, vit_transform
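A minimal usage sketch, not part of the commit, of how app.py is expected to consume the new module: load the split ViT and its transform once, preprocess a PIL image, and run the first 11 blocks to get the token activations that feed the SAE. The image path is a placeholder.

    import torch
    from PIL import Image

    import modeling

    device = "cuda" if torch.cuda.is_available() else "cpu"
    split_vit, vit_transform = modeling.load_vit(device)

    img = Image.open("example.jpg").convert("RGB")  # placeholder path
    x_BCWH = vit_transform(img)[None, ...].to(device)  # 1x3x224x224 after the transform

    with torch.inference_mode():
        x_BPD = split_vit.forward_start(x_BCWH)

    # Expect [1, 1 + 4 + 256, 768] for ViT-B/14 with registers: CLS + register + patch tokens.
    print(x_BPD.shape)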