CTM Experiments - Continuous Thought Machine Models

Experimental checkpoints trained with the Continuous Thought Machine (CTM) architecture from Sakana AI.

These are community experiments on top of the original work, not official Sakana AI models.

Paper Reference

Continuous Thought Machines

Sakana AI

arXiv:2505.05522

Interactive Demo | Blog Post

@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={Sakana AI},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}

Core Insight

CTM's key innovation: accuracy improves with more internal iterations. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with.
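
As a quick probe of this behavior, per-tick accuracy can be measured directly. The sketch below is hypothetical: it assumes the forward pass returns per-tick predictions shaped (batch, classes, ticks) as in the reference implementation, and `model`, `inputs`, and `labels` are placeholders.

import torch

# Measure accuracy after each internal tick; later ticks should score
# higher if the model benefits from "thinking longer".
with torch.no_grad():
    predictions, certainties, _ = model(inputs)  # assumed return signature
    for tick in range(predictions.shape[-1]):
        acc = (predictions[..., tick].argmax(dim=1) == labels).float().mean()
        print(f"tick {tick:2d}: accuracy {acc:.3f}")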

Models

| Model | File | Size | Task | Accuracy | Description |
|---|---|---|---|---|---|
| MNIST | ctm-mnist.pt | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | ctm-parity-16.pt | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | ctm-parity-64.pt | 66M | Cumulative parity | 58.6% | 64-bit sequences (custom config) |
| Parity-64 Official | ctm-parity-64-official.pt | 21M | Cumulative parity | 57.7% | 64-bit sequences (official config) |
| QAMNIST | ctm-qamnist.pt | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | ctm-brackets.pt | 6.1M | Bracket matching | 94.7% | Valid/invalid (()[]) |
| Tracking-Quadrant | ctm-tracking-quadrant.pt | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | ctm-tracking-position.pt | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | ctm-transfer-parity-brackets.pt | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
| Jigsaw MNIST | ctm-jigsaw-mnist.pt | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
| Rotation MNIST | ctm-rotation-mnist.pt | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
| Brackets Transfer | ctm-brackets-transfer-depth4.pt | 6.1M | Transfer learning | 95.1% | Parity→Brackets (depth 4 synapse) |
| Dual-Task | ctm-dual-task-brackets-parity.pt | 2.8M | Multi-task | 86.1% | Brackets (94%) + Parity (78%) jointly |
| Parity-64 | ctm-parity-64-8x8.pt | 4.1M | Long parity | 58.6% | 64-bit (8x8) cumulative parity |
| Parity-144 | ctm-parity-144-12x12.pt | 4.1M | Long parity | 51.7% | 144-bit (12x12) cumulative parity |
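
All checkpoints can be browsed and fetched with the standard huggingface_hub helpers:

from huggingface_hub import hf_hub_download, list_repo_files

# List every checkpoint in the repo, then download one by filename.
checkpoints = [f for f in list_repo_files("vincentoh/ctm-experiments")
               if f.endswith(".pt")]
model_path = hf_hub_download(repo_id="vincentoh/ctm-experiments",
                             filename="ctm-mnist.pt")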

Model Configurations

MNIST CTM

config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
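
As a sanity check, the parameter count implied by this config should line up with the file size in the table above (roughly parameters x 4 bytes for float32 weights). This assumes the ContinuousThoughtMachine class from the Sakana AI reference repo, which may require additional constructor arguments not listed here.

from models.ctm import ContinuousThoughtMachine  # Sakana AI reference repo

# Checkpoint size ~ parameter count x 4 bytes (float32).
model = ContinuousThoughtMachine(**config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")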

Parity-16 CTM

config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
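
Cumulative parity targets are cheap to synthesize; a minimal sketch (the input encoding used in training is an assumption):

import torch

def make_parity_batch(batch_size: int, n_bits: int = 16):
    # Target t is the XOR of bits 0..t, i.e. a prefix sum mod 2.
    bits = torch.randint(0, 2, (batch_size, n_bits))
    targets = bits.cumsum(dim=1) % 2
    inputs = bits.float() * 2 - 1  # assumed encoding: {0,1} -> {-1,+1}
    return inputs, targets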

Parity-64 Official CTM

config = {
    "iterations": 75,
    "memory_length": 25,
    "d_model": 1024,
    "d_input": 64,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,  # linear synapse (official)
    "out_dims": 64,  # cumulative parity
}

QAMNIST CTM

config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}

Brackets CTM

config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
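
Balanced bracket strings can be generated with a simple stack-based sampler; a sketch (the exact dataset construction is an assumption):

import random

PAIRS = {"(": ")", "[": "]"}

def random_valid(n_pairs: int) -> str:
    # Randomly open or close while keeping the sequence balanced.
    out, stack, opens_left = [], [], n_pairs
    while opens_left or stack:
        if opens_left and (not stack or random.random() < 0.5):
            bracket = random.choice("([")
            stack.append(bracket)
            out.append(bracket)
            opens_left -= 1
        else:
            out.append(PAIRS[stack.pop()])
    return "".join(out)

# Invalid examples can be made by corrupting a character and re-checking,
# e.g. s = random_valid(8); s = s[:3] + ")" + s[4:]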

Tracking CTM

config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}

Jigsaw MNIST CTM

config = {
    "iterations": 30,
    "memory_length": 20,
    "d_model": 512,
    "d_input": 128,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 24,  # 4 tiles x 6 permutation options
    "backbone_type": "jigsaw",
}
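
For intuition, the 2x2 jigsaw input can be built by tiling and permuting a 28x28 MNIST image; a hypothetical sketch (the repo's exact preprocessing may differ):

import torch

def make_jigsaw(img: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # Split a 28x28 image into four 14x14 tiles (row-major 2x2 grid),
    # reorder them by `perm`, and stitch the grid back together.
    tiles = img.unfold(0, 14, 14).unfold(1, 14, 14).reshape(4, 14, 14)
    tiles = tiles[perm]
    top = torch.cat([tiles[0], tiles[1]], dim=1)
    bottom = torch.cat([tiles[2], tiles[3]], dim=1)
    return torch.cat([top, bottom], dim=0)

# e.g. shuffled = make_jigsaw(img, torch.randperm(4))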

Rotation MNIST CTM

config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 4,  # 0°, 90°, 180°, 270°
    "backbone_type": "rotation",
}
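
Rotation labels are equally cheap to generate, since the four classes are multiples of 90 degrees; a sketch (assumed preprocessing):

import torch

def rotate_batch(images: torch.Tensor):
    # Label k in {0, 1, 2, 3} means a rotation of k * 90 degrees.
    labels = torch.randint(0, 4, (images.shape[0],))
    rotated = torch.stack([torch.rot90(img, int(k), dims=(-2, -1))
                           for img, k in zip(images, labels)])
    return rotated, labels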

Usage

import torch
from huggingface_hub import hf_hub_download

# The CTM implementation lives in the Sakana AI reference repo
# (models/ctm.py); clone it and run from its root.
from models.ctm import ContinuousThoughtMachine

# Download a checkpoint
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt"
)

# Load the checkpoint
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize a CTM with the matching config from "Model Configurations" above
model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Inference (input_tensor: a batch of task inputs, e.g. MNIST images)
with torch.no_grad():
    output = model(input_tensor)

Training Details

  • Hardware: NVIDIA RTX 4070 Ti SUPER
  • Framework: PyTorch
  • Optimizer: AdamW
  • Training time: 5 minutes (MNIST) to 17 hours (QAMNIST)
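
The training code is not bundled with the checkpoints. A minimal AdamW loop consistent with the details above might look like the following; the per-tick output shape, final-tick loss, and learning rate are all assumptions:

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for inputs, targets in loader:
    predictions, certainties, _ = model(inputs)  # assumed: (B, classes, ticks)
    loss = F.cross_entropy(predictions[..., -1], targets)  # final-tick loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()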

Key Findings

  1. Architecture > Scale: Small sync dimensions (32) with linear synapses work better than large/deep variants
  2. "Thinking Longer" = Higher Accuracy: CTM accuracy improves with more internal iterations
  3. Transfer Learning Works: A parity-trained core transfers to brackets with 94.5% accuracy (see the sketch after this list)
  4. Architectural Limits: CTM has a ~58% ceiling on 64-bit parity regardless of hyperparameters
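
A rough recipe for the transfer result in finding 3; the head's parameter prefix ("out_proj") is a hypothetical name, so inspect the checkpoint keys to confirm:

import torch

# Load the parity-trained weights, drop the task-specific output head,
# and fine-tune the rest on brackets.
ckpt = torch.load("ctm-parity-16.pt", map_location="cpu")["model_state_dict"]
core = {k: v for k, v in ckpt.items() if not k.startswith("out_proj")}
missing, unexpected = brackets_model.load_state_dict(core, strict=False)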

Parity Scaling Experiments

We tested CTM on increasingly long parity sequences to find where it breaks down:

| Sequence | Grid | Accuracy | vs Random | Status |
|---|---|---|---|---|
| 16 | 4x4 | 99.0% | +49.0% | ✅ Solved |
| 36 | 6x6 | 66.3% | +16.3% | ⚠️ Degraded |
| 64 | 8x8 | 58.6% | +8.6% | ❌ Struggling |
| 64 (official) | 8x8 | 57.7% | +7.7% | ❌ Same ceiling |
| 144 | 12x12 | 51.7% | +1.7% | ❌ Random |

Key insight: The ~58% ceiling for parity-64 is an architectural limit, not a hyperparameter issue. Both custom config (d_model=512, synapse_depth=4) and official config (d_model=1024, synapse_depth=1) achieve essentially the same accuracy.

Why CTM Fails on Long Parity

Parity requires strictly sequential computation: bit 1 must be processed before bit 2, bit 2 before bit 3, and so on. CTM's attention-based "thinking" is fundamentally parallel: all positions attend simultaneously. The model can learn approximate sequential patterns for short sequences (up to ~64 steps), but the approximation breaks down beyond that.
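
The sequential dependency is easiest to see as a recurrence: each target is defined by the previous one, forcing an O(n) chain that parallel attention can only approximate (`bits` below is a placeholder list of 0/1 ints).

# target[t] = target[t-1] XOR bits[t]; no prefix can be skipped.
state = 0
targets = []
for bit in bits:
    state ^= bit
    targets.append(state)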

CTM excels at:

  • Moderate sequence lengths (< 64 elements)
  • Local dependencies (brackets: track depth, not full history)
  • Parallelizable structure (MNIST: patches contribute independently)

CTM struggles with:

  • Long strict sequential dependencies (parity-144)
  • Tasks requiring O(n) sequential steps where n > ~64

License

MIT License (same as original CTM repository)

Acknowledgments

Thanks to Sakana AI for the original Continuous Thought Machine architecture and reference implementation.
