# AxoNet VAE Stage 1
A variational autoencoder for semantic segmentation of neuronal morphologies from 2D projections.
## Model Description
AxoNet VAE is a U-Net-based variational autoencoder with:
- Variational bottleneck: Global latent space for neuron embeddings
- Variational skip connections: Prevents information bypass around the bottleneck
- Multi-task heads: Segmentation (6 classes) + depth estimation
### Architecture
- Base channels: 64
- Latent channels: 128
- Input: Single-channel grayscale (512x512)
- Output: 6-class segmentation + depth map
### Segmentation Classes
| Index | Class | Description |
|---|---|---|
| 0 | background | Non-neuron pixels |
| 1 | soma | Cell body |
| 2 | axon | Axonal processes |
| 3 | basal_dendrite | Basal dendritic arbor |
| 4 | apical_dendrite | Apical dendritic arbor |
| 5 | other | Unclassified neurite |
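For visualization or downstream analysis, the class indices above can be mapped to names and colors. A minimal sketch with NumPy (the RGB palette is an arbitrary choice for illustration, not part of the model):

```python
import numpy as np

# Class index -> name mapping, matching the table above
CLASS_NAMES = [
    "background", "soma", "axon",
    "basal_dendrite", "apical_dendrite", "other",
]

# Arbitrary RGB palette for visualization (one color per class)
PALETTE = np.array([
    [0, 0, 0],        # background
    [255, 0, 0],      # soma
    [0, 128, 255],    # axon
    [0, 200, 0],      # basal_dendrite
    [200, 0, 200],    # apical_dendrite
    [128, 128, 128],  # other
], dtype=np.uint8)

def colorize(seg: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return PALETTE[seg]

seg = np.array([[0, 1], [2, 5]])
rgb = colorize(seg)  # shape (2, 2, 3)
```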
## Training
- Dataset: broadinstitute/axonet-neuromorpho-dataset
- Neurons: 7,158 (curated from NeuroMorpho.org)
- Images: ~164K multi-view renderings
- Epochs: 22
- Batch size: 32
- Hardware: 2x NVIDIA A100 80GB
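For a rough sense of the training scale these numbers imply, a back-of-the-envelope calculation (assuming exactly 164K images and a global batch size of 32; the true image count is approximate, and if 32 is the per-GPU batch across the 2 GPUs, halve the step counts):

```python
# Approximate optimizer steps implied by the training setup above
num_images = 164_000   # ~164K multi-view renderings (approximate)
batch_size = 32        # assumed global batch size
epochs = 22

steps_per_epoch = num_images // batch_size
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)
```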
## Usage

### Quick Start

```python
import torch
from huggingface_hub import hf_hub_download

# Download the model weights
model_path = hf_hub_download(
    repo_id="broadinstitute/axonet-vae-stage1",
    filename="pytorch_model.bin",
)

# Load weights (requires the axonet package)
from axonet.models.d3_swc_vae import SegVAE2D

model = SegVAE2D(
    in_channels=1,
    base_channels=64,
    num_classes=6,
    latent_channels=128,
    skip_mode="variational",
)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
```
### Inference

```python
import numpy as np
import torch
from PIL import Image

# Load and preprocess the input image
img = Image.open("neuron_mask.png").convert("L")
img = img.resize((512, 512))
tensor = torch.from_numpy(np.array(img) / 255.0).float()
tensor = tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, 512, 512)

# Run inference
with torch.no_grad():
    outputs = model(tensor, return_latent=True)

segmentation = outputs["seg_logits"].argmax(dim=1)  # (1, 512, 512)
depth = outputs["depth"]                            # (1, 1, 512, 512)
embedding = outputs["mu"].mean(dim=(2, 3))          # (1, 128) latent embedding
```
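A quick way to summarize a predicted segmentation is per-class pixel fractions. A sketch with NumPy (using a random stand-in map here; in practice you would pass `segmentation[0].cpu().numpy()` from the snippet above):

```python
import numpy as np

# Stand-in for a predicted (512, 512) class-index map
rng = np.random.default_rng(0)
seg = rng.integers(0, 6, size=(512, 512))

# Count pixels per class (minlength=6 keeps absent classes at zero)
counts = np.bincount(seg.ravel(), minlength=6)
fractions = counts / counts.sum()

for idx, frac in enumerate(fractions):
    print(f"class {idx}: {frac:.3f}")
```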
### Extract Embeddings

```python
# Get a neuron embedding for downstream tasks
with torch.no_grad():
    z, mu, logvar, _, _, _ = model.encode(tensor)

embedding = mu.mean(dim=(2, 3))  # global average pooling over spatial dims
```
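The pooled `mu` vectors can be compared directly, for example with cosine similarity, to find morphologically similar neurons. A sketch using random stand-ins for the (128,) embeddings produced above:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-ins for two pooled 128-d neuron embeddings (mu.mean(dim=(2, 3)))
rng = np.random.default_rng(42)
e1 = rng.standard_normal(128)
e2 = rng.standard_normal(128)

sim_self = cosine_similarity(e1, e1)  # identical vectors score 1.0
sim_pair = cosine_similarity(e1, e2)
```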
## Files

| File | Description |
|---|---|
| `pytorch_model.bin` | PyTorch state dict |
| `model.safetensors` | Safetensors format (recommended) |
| `config.json` | Model configuration |
| `full_checkpoint/best.ckpt` | Full Lightning checkpoint |
## Downstream Model
This model serves as the encoder for:
- broadinstitute/axonet-clip-stage2 - CLIP model for text-image retrieval
## Citation

```bibtex
@misc{axonet2025,
  author = {Hall, Giles},
  title = {AxoNet: Multimodal Neuron Morphology Embeddings via 2D Projections},
  year = {2025},
  publisher = {HuggingFace},
  howpublished = {\url{https://huggingface.co/broadinstitute/axonet-vae-stage1}}
}
```
## License
MIT License