Samuel Stevens
committed on
Commit
·
0ab58fa
1
Parent(s):
e508563
Use cloudflare for ade20k images
Browse files
- app.py +9 -12
- constants.py +0 -1
- data.py +15 -96
app.py
CHANGED
|
@@ -143,9 +143,9 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
|
|
| 143 |
return torch.load(path, weights_only=True, map_location="cpu")
|
| 144 |
|
| 145 |
|
| 146 |
-
top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
|
| 147 |
-
top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
|
| 148 |
-
sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
|
| 149 |
|
| 150 |
|
| 151 |
# mask = torch.ones((sae.cfg.d_sae), dtype=bool)
|
|
@@ -231,14 +231,13 @@ class SaeActivation(typing.TypedDict):
|
|
| 231 |
|
| 232 |
|
| 233 |
@beartype.beartype
|
| 234 |
-
def get_image(
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
seg_sized = data.to_sized(sample["segmentation"])
|
| 238 |
seg_u8_sized = data.to_u8(seg_sized)
|
| 239 |
seg_img_sized = data.u8_to_img(seg_u8_sized)
|
| 240 |
|
| 241 |
-
return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized),
|
| 242 |
|
| 243 |
|
| 244 |
@beartype.beartype
|
|
@@ -253,9 +252,9 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
|
|
| 253 |
vit, vit_transform = load_vit()
|
| 254 |
sae = load_sae()
|
| 255 |
|
| 256 |
-
|
| 257 |
|
| 258 |
-
x = vit_transform(
|
| 259 |
|
| 260 |
_, vit_acts_BLPD = vit(x)
|
| 261 |
vit_acts_PD = (
|
|
@@ -268,8 +267,6 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
|
|
| 268 |
acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
|
| 269 |
logger.info("Got SAE activations.")
|
| 270 |
|
| 271 |
-
breakpoint()
|
| 272 |
-
|
| 273 |
top_img_i, top_values = load_tensors(model_cfg)
|
| 274 |
logger.info("Loaded top SAE activations for '%s'.", model_name)
|
| 275 |
|
|
|
|
| 143 |
return torch.load(path, weights_only=True, map_location="cpu")
|
| 144 |
|
| 145 |
|
| 146 |
+
# top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
|
| 147 |
+
# top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
|
| 148 |
+
# sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
|
| 149 |
|
| 150 |
|
| 151 |
# mask = torch.ones((sae.cfg.d_sae), dtype=bool)
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
@beartype.beartype
def get_image(i: int) -> tuple[str, str, int]:
    """Return example `i` and its segmentation as base64 strings, plus `i`.

    The image comes from `data.get_image(i)` and the segmentation from
    `data.get_seg(i)`. Both are passed through `data.to_sized`; the
    segmentation additionally goes through `data.to_u8` and `data.u8_to_img`
    before both results are encoded with `data.img_to_base64`.

    Args:
        i: Index of the example to fetch.

    Returns:
        A `(image_base64, segmentation_base64, i)` tuple; the index is echoed
        back unchanged.
    """
    # Fetch the image before the segmentation, matching the order in which the
    # two remote lookups are issued.
    photo = data.to_sized(data.get_image(i))
    seg_img = data.u8_to_img(data.to_u8(data.to_sized(data.get_seg(i))))

    encoded_photo = data.img_to_base64(photo)
    encoded_seg = data.img_to_base64(seg_img)
    return encoded_photo, encoded_seg, i
|
| 241 |
|
| 242 |
|
| 243 |
@beartype.beartype
|
|
|
|
| 252 |
vit, vit_transform = load_vit()
|
| 253 |
sae = load_sae()
|
| 254 |
|
| 255 |
+
img = data.get_image(image_i)
|
| 256 |
|
| 257 |
+
x = vit_transform(img)[None, ...].to(DEVICE)
|
| 258 |
|
| 259 |
_, vit_acts_BLPD = vit(x)
|
| 260 |
vit_acts_PD = (
|
|
|
|
| 267 |
acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
|
| 268 |
logger.info("Got SAE activations.")
|
| 269 |
|
|
|
|
|
|
|
| 270 |
top_img_i, top_values = load_tensors(model_cfg)
|
| 271 |
logger.info("Loaded top SAE activations for '%s'.", model_name)
|
| 272 |
|
constants.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
-
|
| 4 |
DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
|
| 5 |
|
| 6 |
|
|
|
|
| 1 |
import torch
|
| 2 |
|
|
|
|
| 3 |
DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
|
| 4 |
|
| 5 |
|
data.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
| 1 |
import base64
|
| 2 |
-
import dataclasses
|
| 3 |
import functools
|
| 4 |
import io
|
| 5 |
import logging
|
| 6 |
-
import os.path
|
| 7 |
import random
|
| 8 |
|
| 9 |
import beartype
|
| 10 |
import einops.layers.torch
|
| 11 |
import numpy as np
|
| 12 |
-
import
|
| 13 |
from jaxtyping import UInt8, jaxtyped
|
| 14 |
from PIL import Image
|
| 15 |
from torch import Tensor
|
|
@@ -17,104 +15,25 @@ from torchvision.transforms import v2
|
|
| 17 |
|
| 18 |
logger = logging.getLogger("data.py")
|
| 19 |
|
|
|
|
|
|
|
| 20 |
|
| 21 |
@beartype.beartype
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
label: str
|
| 29 |
-
target: int
|
| 30 |
-
|
| 31 |
-
samples: list[Sample]
|
| 32 |
-
|
| 33 |
-
def __init__(self, root: str, split: str):
|
| 34 |
-
self.logger = logging.getLogger("ade20k")
|
| 35 |
-
self.root = root
|
| 36 |
-
self.split = split
|
| 37 |
-
self.img_dir = os.path.join(root, "images")
|
| 38 |
-
self.seg_dir = os.path.join(root, "annotations")
|
| 39 |
-
|
| 40 |
-
# Check that we have the right path.
|
| 41 |
-
for subdir in ("images", "annotations"):
|
| 42 |
-
if not os.path.isdir(os.path.join(root, subdir)):
|
| 43 |
-
# Something is missing.
|
| 44 |
-
if os.path.realpath(root).endswith(subdir):
|
| 45 |
-
self.logger.warning(
|
| 46 |
-
"The ADE20K root should contain 'images/' and 'annotations/' directories."
|
| 47 |
-
)
|
| 48 |
-
raise ValueError(f"Can't find path '{os.path.join(root, subdir)}'.")
|
| 49 |
-
|
| 50 |
-
_, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
|
| 51 |
-
split_lookup: dict[int, str] = {
|
| 52 |
-
value: key for key, value in split_mapping.items()
|
| 53 |
-
}
|
| 54 |
-
self.loader = torchvision.datasets.folder.default_loader
|
| 55 |
-
|
| 56 |
-
err_msg = f"Split '{split}' not in '{set(split_lookup.values())}'."
|
| 57 |
-
assert split in set(split_lookup.values()), err_msg
|
| 58 |
-
|
| 59 |
-
# Load all the image paths.
|
| 60 |
-
imgs: list[str] = [
|
| 61 |
-
path
|
| 62 |
-
for path, s in torchvision.datasets.folder.make_dataset(
|
| 63 |
-
self.img_dir,
|
| 64 |
-
split_mapping,
|
| 65 |
-
extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
|
| 66 |
-
)
|
| 67 |
-
if split_lookup[s] == split
|
| 68 |
-
]
|
| 69 |
-
|
| 70 |
-
segs: list[str] = [
|
| 71 |
-
path
|
| 72 |
-
for path, s in torchvision.datasets.folder.make_dataset(
|
| 73 |
-
self.seg_dir,
|
| 74 |
-
split_mapping,
|
| 75 |
-
extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
|
| 76 |
-
)
|
| 77 |
-
if split_lookup[s] == split
|
| 78 |
-
]
|
| 79 |
-
|
| 80 |
-
# Load all the targets, classes and mappings
|
| 81 |
-
with open(os.path.join(root, "sceneCategories.txt")) as fd:
|
| 82 |
-
img_labels: list[str] = [line.split()[1] for line in fd.readlines()]
|
| 83 |
-
|
| 84 |
-
label_set = sorted(set(img_labels))
|
| 85 |
-
label_to_idx = {label: i for i, label in enumerate(label_set)}
|
| 86 |
-
|
| 87 |
-
self.samples = [
|
| 88 |
-
self.Sample(img_path, seg_path, label, label_to_idx[label])
|
| 89 |
-
for img_path, seg_path, label in zip(imgs, segs, img_labels)
|
| 90 |
-
]
|
| 91 |
-
|
| 92 |
-
def __getitem__(self, index: int) -> dict[str, object]:
|
| 93 |
-
# Convert to dict.
|
| 94 |
-
sample = dataclasses.asdict(self.samples[index])
|
| 95 |
-
|
| 96 |
-
sample["image"] = self.loader(sample.pop("img_path"))
|
| 97 |
-
sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
|
| 98 |
-
sample["index"] = index
|
| 99 |
-
|
| 100 |
-
return sample
|
| 101 |
-
|
| 102 |
-
def __len__(self) -> int:
|
| 103 |
-
return len(self.samples)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
@functools.cache
|
| 107 |
-
def get_dataset() -> Ade20k:
|
| 108 |
-
return Ade20k(
|
| 109 |
-
root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/",
|
| 110 |
-
split="validation",
|
| 111 |
-
)
|
| 112 |
|
| 113 |
|
| 114 |
@beartype.beartype
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
@jaxtyped(typechecker=beartype.beartype)
|
|
|
|
| 1 |
import base64
|
|
|
|
| 2 |
import functools
|
| 3 |
import io
|
| 4 |
import logging
|
|
|
|
| 5 |
import random
|
| 6 |
|
| 7 |
import beartype
|
| 8 |
import einops.layers.torch
|
| 9 |
import numpy as np
|
| 10 |
+
import requests
|
| 11 |
from jaxtyping import UInt8, jaxtyped
|
| 12 |
from PIL import Image
|
| 13 |
from torch import Tensor
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger("data.py")
|
| 17 |
|
| 18 |
+
R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
|
| 19 |
+
|
| 20 |
|
| 21 |
@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_image(i: int) -> Image.Image:
    """Download image `i` from the bucket at `R2_URL`.

    Args:
        i: Zero-based example index. The remote file name uses `i + 1`,
            zero-padded to eight digits (`/images/ADE_val_XXXXXXXX.jpg`).

    Returns:
        A PIL image opened from the downloaded bytes. Results are memoized
        (up to 512 entries), so repeated calls with the same `i` return the
        same `Image` object without another request.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    fpath = f"/images/ADE_val_{i + 1:08}.jpg"
    url = R2_URL + fpath
    logger.info("Getting image from '%s'.", url)

    # Bound the wait so a stalled connection cannot hang the caller forever.
    resp = requests.get(url, timeout=30)
    # Surface HTTP errors directly instead of letting PIL fail on an error page.
    resp.raise_for_status()
    # Use the fully-read, content-decoded body; the in-memory buffer is
    # seekable and keeps no network connection attached to the cached image.
    return Image.open(io.BytesIO(resp.content))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_seg(i: int) -> Image.Image:
    """Download the annotation image for example `i` from the bucket at `R2_URL`.

    Args:
        i: Zero-based example index. The remote file name uses `i + 1`,
            zero-padded to eight digits (`/annotations/ADE_val_XXXXXXXX.png`).

    Returns:
        A PIL image opened from the downloaded bytes. Results are memoized
        (up to 512 entries), so repeated calls with the same `i` return the
        same `Image` object without another request.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    fpath = f"/annotations/ADE_val_{i + 1:08}.png"
    url = R2_URL + fpath
    logger.info("Getting annotations from '%s'.", url)

    # Bound the wait so a stalled connection cannot hang the caller forever.
    resp = requests.get(url, timeout=30)
    # Surface HTTP errors directly instead of letting PIL fail on an error page.
    resp.raise_for_status()
    # Use the fully-read, content-decoded body; the in-memory buffer is
    # seekable and keeps no network connection attached to the cached image.
    return Image.open(io.BytesIO(resp.content))
|
| 37 |
|
| 38 |
|
| 39 |
@jaxtyped(typechecker=beartype.beartype)
|