TAP-CT: 3D Task-Agnostic Pretraining of CT Foundation Models

TAP-CT is a suite of foundation models for computed tomography (CT) imaging, pretrained in a task-agnostic manner through an adaptation of DINOv2 for volumetric data. These models learn robust 3D representations from CT scans without requiring task-specific annotations.

This repository provides TAP-CT-B-3D, a Vision Transformer (ViT-Base) architecture pretrained on volumetric inputs with a spatial resolution of (12, 224, 224) and a patch size of (4, 8, 8). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan. Additional TAP-CT model variants, as well as the image processor, will be released in future updates.
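
As an illustration, a minimal sketch of such a sliding-window feature extractor is given below. The crop size, stride, and choice of pooled features are illustrative assumptions rather than a prescribed implementation, and the model is assumed to be loaded as shown in the Usage section.

import torch

def extract_features_sliding_window(model, volume, crop=(12, 224, 224), stride=(12, 224, 224)):
    """Pool TAP-CT features over a full preprocessed CT volume.

    volume: tensor of shape (1, 1, depth, height, width); each spatial axis is
    assumed to be at least as large as the corresponding crop size.
    Returns a tensor of shape (num_crops, hidden_dim) with one pooled feature per crop.
    """
    _, _, depth, height, width = volume.shape
    features = []
    with torch.no_grad():
        for z0 in range(0, depth - crop[0] + 1, stride[0]):
            for y0 in range(0, height - crop[1] + 1, stride[1]):
                for x0 in range(0, width - crop[2] + 1, stride[2]):
                    window = volume[:, :, z0:z0 + crop[0], y0:y0 + crop[1], x0:x0 + crop[2]]
                    features.append(model(window).pooler_output)
    return torch.cat(features, dim=0)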

Preprocessing

A dedicated image processor will be released in a future update; until then, optimal feature extraction requires the following preprocessing pipeline (a minimal sketch is given after the list):

  1. Orientation: Convert the volume to LPS (Left-Posterior-Superior) orientation. While the model is likely orientation-invariant, all evaluations were conducted using LPS orientation.
  2. Spatial Resizing: Resize the volume to a spatial resolution of (z, 224, 224) or (z, 512, 512), where z is the number of slices along the axial dimension.
  3. Axial Padding: Apply zero-padding along the z-axis to ensure the number of slices is divisible by 4, accommodating the model's patch size of (4, 8, 8).
  4. Intensity Clipping: Clip voxel intensities to the range [-1008, 822] HU (Hounsfield Units).
  5. Normalization: Apply z-score normalization using mean = -86.8086 and std = 322.6347.
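
The sketch below implements steps 2-5 in the order listed, assuming the volume has already been loaded and reoriented to LPS (step 1, e.g. with nibabel or SimpleITK). The use of scipy.ndimage.zoom with trilinear interpolation is an illustrative choice, not a prescribed implementation.

import numpy as np
from scipy.ndimage import zoom

def preprocess_ct(volume_lps: np.ndarray, target_hw: int = 224) -> np.ndarray:
    """Apply steps 2-5 to an LPS-oriented CT volume of shape (z, y, x) in HU."""
    z, y, x = volume_lps.shape

    # 2. Spatial resizing: keep z, resize in-plane to (target_hw, target_hw).
    vol = zoom(volume_lps, (1.0, target_hw / y, target_hw / x), order=1)

    # 3. Axial padding: zero-pad along z until divisible by the depth patch size (4).
    pad_z = (-vol.shape[0]) % 4
    vol = np.pad(vol, ((0, pad_z), (0, 0), (0, 0)), mode='constant', constant_values=0)

    # 4. Intensity clipping to [-1008, 822] HU.
    vol = np.clip(vol, -1008, 822)

    # 5. Z-score normalization with the stated statistics.
    vol = (vol - (-86.8086)) / 322.6347

    # Add batch and channel dimensions -> (1, 1, depth, height, width).
    return vol[np.newaxis, np.newaxis].astype(np.float32)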

Usage

import torch
from transformers import AutoModel

# Load the model
model = AutoModel.from_pretrained('fomofo/tap-ct-b-3d', trust_remote_code=True)

# Prepare input (batch_size, channels, depth, height, width)
x = torch.randn((16, 1, 12, 224, 224))

# Forward pass (no gradients needed for feature extraction)
with torch.no_grad():
    output = model(x)

The model returns a BaseModelOutputWithPooling object from the transformers library. The output.pooler_output contains the pooled [CLS] token representation, while output.last_hidden_state contains the spatial patch token embeddings. To extract features from all intermediate transformer layers, pass output_hidden_states=True to the forward method.
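
Continuing the example above, the individual outputs can be accessed as follows; the variable names on the left are illustrative.

# Pooled [CLS] token representation, shape (batch_size, hidden_dim)
cls_features = output.pooler_output

# Spatial patch token embeddings, shape (batch_size, num_tokens, hidden_dim)
patch_tokens = output.last_hidden_state

# Hidden states from all intermediate transformer layers (tuple of tensors)
output_with_states = model(x, output_hidden_states=True)
all_layer_features = output_with_states.hidden_states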

Model Details

  • Model Type: 3D CT Vision Foundation Model
  • Input Shape: (batch_size, 1, depth, height, width)
  • Example Input: (16, 1, 12, 224, 224) - batch of 16 CT crops with 12 slices at 224×224 resolution
  • License: CC-BY-NC-4.0