AxoNet VAE Stage 1

A variational autoencoder for semantic segmentation of neuronal morphologies from 2D projections.

Model Description

AxoNet VAE is a U-Net architecture with:

  • Variational bottleneck: Global latent space for neuron embeddings
  • Variational skip connections: Prevent information bypass around the bottleneck
  • Multi-task heads: Segmentation (6 classes) + depth estimation

Architecture

  • Base channels: 64
  • Latent channels: 128
  • Input: Single-channel grayscale (512x512)
  • Output: 6-class segmentation + depth map

Segmentation Classes

Index  Class            Description
0      background       Non-neuron pixels
1      soma             Cell body
2      axon             Axonal processes
3      basal_dendrite   Basal dendritic arbor
4      apical_dendrite  Apical dendritic arbor
5      other            Unclassified neurite
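The class indices above can be mapped to colors for quick visual inspection of a predicted segmentation map. A minimal sketch; the palette colors are arbitrary choices for illustration, not part of the model:

```python
import numpy as np

# Arbitrary RGB palette, one color per class index 0-5
PALETTE = np.array([
    [0, 0, 0],        # 0: background
    [255, 0, 0],      # 1: soma
    [0, 128, 255],    # 2: axon
    [0, 200, 0],      # 3: basal_dendrite
    [255, 200, 0],    # 4: apical_dendrite
    [160, 160, 160],  # 5: other
], dtype=np.uint8)

def colorize(seg: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return PALETTE[seg]
```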

Training

Usage

Quick Start

import torch
from huggingface_hub import hf_hub_download

# Download model
model_path = hf_hub_download(
    repo_id="broadinstitute/axonet-vae-stage1",
    filename="pytorch_model.bin"
)

# Load weights (requires axonet package)
from axonet.models.d3_swc_vae import SegVAE2D

model = SegVAE2D(
    in_channels=1,
    base_channels=64,
    num_classes=6,
    latent_channels=128,
    skip_mode="variational",
)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

Inference

import torch
from PIL import Image
import numpy as np

# Load and preprocess image
img = Image.open("neuron_mask.png").convert("L")
img = img.resize((512, 512))
tensor = torch.from_numpy(np.array(img) / 255.0).float()
tensor = tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, 512, 512)

# Run inference
with torch.no_grad():
    outputs = model(tensor, return_latent=True)

segmentation = outputs["seg_logits"].argmax(dim=1)  # (1, 512, 512)
depth = outputs["depth"]  # (1, 1, 512, 512)
embedding = outputs["mu"].mean(dim=(2, 3))  # (1, 128) - latent embedding
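A prediction can be summarized by counting pixels per class in the argmax map, e.g. to compare the relative extent of axonal vs dendritic arbor. A small sketch; the class names repeat the table above:

```python
import torch

CLASS_NAMES = ["background", "soma", "axon",
               "basal_dendrite", "apical_dendrite", "other"]

def class_pixel_counts(seg: torch.Tensor) -> dict:
    """Count pixels per class in an integer segmentation map."""
    counts = torch.bincount(seg.flatten(), minlength=len(CLASS_NAMES))
    return {name: int(n) for name, n in zip(CLASS_NAMES, counts)}
```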

Extract Embeddings

# Get neuron embedding for downstream tasks
with torch.no_grad():
    z, mu, logvar, _, _, _ = model.encode(tensor)
    embedding = mu.mean(dim=(2, 3))  # Global average pooling
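Embeddings pooled this way can be compared directly, e.g. with cosine similarity to retrieve morphologically similar neurons. A usage sketch; the random-free helper below assumes (128,) vectors as produced by the pooling above:

```python
import torch
import torch.nn.functional as F

def embedding_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two (128,) neuron embeddings."""
    return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
```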

Files

File                       Description
pytorch_model.bin          PyTorch state dict
model.safetensors          Safetensors format (recommended)
config.json                Model configuration
full_checkpoint/best.ckpt  Full Lightning checkpoint

Downstream Model

This model serves as the encoder for:

Citation

@misc{axonet2025,
  author = {Hall, Giles},
  title = {AxoNet: Multimodal Neuron Morphology Embeddings via 2D Projections},
  year = {2025},
  publisher = {HuggingFace},
  howpublished = {\url{https://huggingface.co/broadinstitute/axonet-vae-stage1}}
}

License

MIT License
