# CTM Experiments - Continuous Thought Machine Models

Experimental checkpoints trained on the Continuous Thought Machine (CTM) architecture by Sakana AI. These are community experiments building on the original work, not official Sakana AI models.
## Paper Reference

**Continuous Thought Machines**, Sakana AI.

```bibtex
@article{sakana2025ctm,
  title={Continuous Thought Machines},
  author={Sakana AI},
  journal={arXiv preprint arXiv:2505.05522},
  year={2025}
}
```
## Core Insight

CTM's key innovation is that accuracy improves with more internal iterations: the model "thinks longer" to reach better answers. This lets CTM learn algorithmic reasoning that feedforward networks struggle with.
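One way to see this empirically is to reload a checkpoint with different iteration budgets and measure accuracy. The sketch below assumes the `ContinuousThoughtMachine` class from the original repository, a `config` dict like those under Model Configurations, and hypothetical `checkpoint`/`test_loader` objects; the forward-pass return shape is also an assumption, so verify it against the repo version you use.

```python
import torch

from models.ctm import ContinuousThoughtMachine  # original CTM repository

def accuracy_at(iterations, config, state_dict, loader, device="cpu"):
    """Accuracy of one checkpoint evaluated with a given iteration budget.

    Assumes `iterations` is a pure runtime loop count, so the same weights
    load regardless of its value (an assumption; verify for your checkout).
    """
    model = ContinuousThoughtMachine(**{**config, "iterations": iterations}).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(device))
            preds = out[0] if isinstance(out, tuple) else out
            if preds.dim() == 3:        # per-tick logits: (batch, classes, ticks)
                preds = preds[..., -1]  # keep the final tick
            correct += (preds.argmax(dim=-1).cpu() == y).sum().item()
            total += y.numel()
    return correct / total

# for n in (5, 15, 30, 50):
#     print(n, accuracy_at(n, config, checkpoint["model_state_dict"], test_loader))
```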
## Models

| Model | File | Size | Task | Accuracy | Description |
|---|---|---|---|---|---|
| MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
| Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
| Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 58.6% | 64-bit sequences (custom config) |
| Parity-64 Official | `ctm-parity-64-official.pt` | 21M | Cumulative parity | 57.7% | 64-bit sequences (official config) |
| QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
| Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
| Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
| Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
| Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
| Jigsaw MNIST | `ctm-jigsaw-mnist.pt` | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
| Rotation MNIST | `ctm-rotation-mnist.pt` | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
| Brackets Transfer | `ctm-brackets-transfer-depth4.pt` | 6.1M | Transfer learning | 95.1% | Parity→Brackets (depth-4 synapse) |
| Dual-Task | `ctm-dual-task-brackets-parity.pt` | 2.8M | Multi-task | 86.1% | Brackets (94%) + Parity (78%) jointly |
| Parity-64 (8x8) | `ctm-parity-64-8x8.pt` | 4.1M | Long parity | 58.6% | 64-bit (8x8) cumulative parity |
| Parity-144 | `ctm-parity-144-12x12.pt` | 4.1M | Long parity | 51.7% | 144-bit (12x12) cumulative parity |
## Model Configurations

### MNIST CTM

```python
config = {
    "iterations": 15,
    "memory_length": 10,
    "d_model": 128,
    "d_input": 128,
    "heads": 2,
    "n_synch_out": 16,
    "n_synch_action": 16,
    "memory_hidden_dims": 8,
    "out_dims": 10,
    "synapse_depth": 1,
}
```
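As a sanity check, one can instantiate a model from this config and count parameters; the total should land near the 1.3M listed in the table. A sketch, assuming the `ContinuousThoughtMachine` class accepts these keyword arguments as-is:

```python
from models.ctm import ContinuousThoughtMachine  # original CTM repository

# Build from the MNIST config above and count trainable parameters.
model = ContinuousThoughtMachine(**config)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M parameters")
```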
### Parity-16 CTM

```python
config = {
    "iterations": 50,
    "memory_length": 25,
    "d_model": 256,
    "d_input": 32,
    "heads": 8,
    "synapse_depth": 8,
    "out_dims": 16,  # cumulative parity
}
```
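The cumulative-parity targets themselves are easy to reproduce: target i is the XOR of bits 0..i, i.e. a cumulative sum modulo 2, so the model needs one output per input bit. A self-contained generator (the ±1 input encoding is an assumption; match whatever the training pipeline used):

```python
import torch

def make_parity_batch(batch_size, seq_len=16, generator=None):
    """Random bit sequences with cumulative-parity targets.

    targets[:, i] is the parity (XOR) of bits[:, :i + 1], hence
    out_dims == seq_len.
    """
    bits = torch.randint(0, 2, (batch_size, seq_len), generator=generator)
    targets = bits.cumsum(dim=1) % 2   # running XOR == cumulative sum mod 2
    inputs = bits.float() * 2 - 1      # assumed {0, 1} -> {-1, +1} encoding
    return inputs, targets

x, y = make_parity_batch(4, seq_len=16)
```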
### Parity-64 Official CTM

```python
config = {
    "iterations": 75,
    "memory_length": 25,
    "d_model": 1024,
    "d_input": 64,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,  # linear synapse (official)
    "out_dims": 64,  # cumulative parity
}
```
### QAMNIST CTM

```python
config = {
    "iterations": 10,
    "memory_length": 30,
    "d_model": 1024,
    "d_input": 64,
    "synapse_depth": 1,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
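The symbolic side of QAMNIST can be sketched without the image pipeline. The generator below is a hypothetical reconstruction of the "3-5 digits, 3-5 ops" setup, with the answer reduced modulo 10 so it stays a single digit; both the chaining and the modulo are assumptions about the target encoding (the real task renders the digits as MNIST images):

```python
import random

def make_qamnist_question(rng=random):
    """Hypothetical QAMNIST-style question: chain 3-5 digits with +/- ops."""
    digits = [rng.randint(0, 9) for _ in range(rng.randint(3, 5))]
    ops = [rng.choice("+-") for _ in range(rng.randint(3, 5))]
    answer = digits[0]
    for i, op in enumerate(ops):
        d = digits[(i + 1) % len(digits)]  # cycle through the shown digits
        answer = answer + d if op == "+" else answer - d
    return digits, ops, answer % 10        # single-digit target (assumed)
```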
### Brackets CTM

```python
config = {
    "iterations": 30,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "out_dims": 2,  # valid/invalid
}
```
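Labels for this task come from a standard stack check; a self-contained sketch (sequence length and class balancing are assumptions about the training data):

```python
import random

PAIRS = {")": "(", "]": "["}

def is_valid(seq):
    """Stack-based matching for sequences over the alphabet ()[]."""
    stack = []
    for ch in seq:
        if ch in "([":
            stack.append(ch)
        elif not stack or stack.pop() != PAIRS[ch]:
            return False
    return not stack

def make_brackets_example(length=16, rng=random):
    # Note: uniform random strings are mostly invalid; a real training
    # pipeline presumably balances the two classes.
    seq = "".join(rng.choice("()[]") for _ in range(length))
    return seq, int(is_valid(seq))  # 1 = valid, 0 = invalid (out_dims = 2)
```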
### Tracking CTM

```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
}
```
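The two tracking checkpoints share this config and differ only in output granularity (4-class quadrant vs. 256-class position). A sketch of how a final (x, y) cell on the 16x16 grid could map to each target, assuming row-major encodings:

```python
GRID = 16  # 16x16 grid -> 256 position classes

def quadrant_label(x, y):
    """4-class quadrant of the final position (assumed encoding)."""
    return 2 * int(y >= GRID // 2) + int(x >= GRID // 2)

def position_label(x, y):
    """256-class row-major index of the final position (assumed encoding)."""
    return y * GRID + x
```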
### Jigsaw MNIST CTM

```python
config = {
    "iterations": 30,
    "memory_length": 20,
    "d_model": 512,
    "d_input": 128,
    "heads": 8,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 24,  # 4! = 24 permutations of the 2x2 tiles
    "backbone_type": "jigsaw",
}
```
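A minimal sketch of building jigsaw inputs: split each MNIST image into 2x2 tiles and reorder them by one of the 4! = 24 permutations, with the permutation index as the target. The labeling scheme is our assumption; it is consistent with `out_dims = 24`:

```python
import itertools
import torch

PERMS = list(itertools.permutations(range(4)))  # 4! = 24 tile orders

def make_jigsaw(img, perm_idx):
    """Split a (1, 28, 28) image into 2x2 tiles and reorder them.

    perm_idx in [0, 24) doubles as the classification target (assumed).
    """
    c, h, w = img.shape
    tiles = [img[:, i * h // 2:(i + 1) * h // 2, j * w // 2:(j + 1) * w // 2]
             for i in range(2) for j in range(2)]
    p = PERMS[perm_idx]
    rows = [torch.cat([tiles[p[0]], tiles[p[1]]], dim=2),
            torch.cat([tiles[p[2]], tiles[p[3]]], dim=2)]
    return torch.cat(rows, dim=1)
```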
### Rotation MNIST CTM

```python
config = {
    "iterations": 20,
    "memory_length": 15,
    "d_model": 256,
    "d_input": 64,
    "heads": 4,
    "n_synch_out": 32,
    "n_synch_action": 32,
    "synapse_depth": 1,
    "out_dims": 4,  # 0°, 90°, 180°, 270°
    "backbone_type": "rotation",
}
```
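Rotated inputs follow directly from `torch.rot90`; in the sketch below the rotation count `k` doubles as the 4-class label (an assumed encoding):

```python
import torch

def make_rotated(img, k):
    """Rotate a (C, H, W) image by k * 90 degrees counter-clockwise.

    k in {0, 1, 2, 3} serves as the 4-class target (assumed encoding).
    """
    return torch.rot90(img, k, dims=(-2, -1)), k
```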
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download a checkpoint
model_path = hf_hub_download(
    repo_id="vincentoh/ctm-experiments",
    filename="ctm-mnist.pt",
)

# Load the checkpoint
checkpoint = torch.load(model_path, map_location="cpu")

# Initialize a CTM with the matching config from "Model Configurations" above
from models.ctm import ContinuousThoughtMachine  # original CTM repository

model = ContinuousThoughtMachine(**config)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Inference
with torch.no_grad():
    output = model(input_tensor)
```
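If the forward pass returns per-tick outputs, as the original repository's CTM does (the exact signature below is an assumption; verify against the repo version you use), you can watch the answer evolve across internal iterations:

```python
# Assumed forward signature: (predictions, certainties, synchronisation),
# with predictions shaped (batch, out_dims, iterations).
with torch.no_grad():
    out = model(input_tensor)

if isinstance(out, tuple):
    predictions = out[0]
    per_tick = predictions.argmax(dim=1)  # (batch, iterations)
    print(per_tick[0].tolist())           # how the answer evolves while "thinking"
```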
## Training Details

- Hardware: NVIDIA RTX 4070 Ti SUPER
- Framework: PyTorch
- Optimizer: AdamW
- Training time: 5 minutes (MNIST) to 17 hours (QAMNIST)
## Key Findings

- Architecture > scale: small synchronization dimensions (32) with linear synapses work better than larger or deeper variants
- "Thinking longer" = higher accuracy: CTM accuracy improves with more internal iterations
- Transfer learning works: a parity-trained core transfers to brackets at 94.5% accuracy
- Architectural limits: CTM hits a ~58% ceiling on 64-bit parity regardless of hyperparameters
## Parity Scaling Experiments

We tested CTM on increasingly long parity sequences to find where it breaks down (the random-guess baseline is 50%):
| Sequence | Grid | Accuracy | vs Random | Status |
|---|---|---|---|---|
| 16 | 4x4 | 99.0% | +49.0% | ✅ Solved |
| 36 | 6x6 | 66.3% | +16.3% | ⚠️ Degraded |
| 64 | 8x8 | 58.6% | +8.6% | ❌ Struggling |
| 64 (official) | 8x8 | 57.7% | +7.7% | ❌ Same ceiling |
| 144 | 12x12 | 51.7% | +1.7% | ❌ Random |
**Key insight:** the ~58% ceiling on parity-64 is an architectural limit, not a hyperparameter issue. Both the custom config (d_model=512, synapse_depth=4) and the official config (d_model=1024, synapse_depth=1) achieve essentially the same accuracy.
## Why CTM Fails on Long Parity

Parity requires strictly sequential computation: bit 1 must be processed before bit 2, bit 2 before bit 3, and so on. CTM's attention-based "thinking" is fundamentally parallel: all positions attend simultaneously. The model can learn approximate sequential patterns for short sequences (up to ~64 steps), but this breaks down for longer ones.

CTM excels at:

- Moderate sequence lengths (< 64 elements)
- Local dependencies (brackets: track depth, not full history)
- Parallelizable structure (MNIST: patches contribute independently)

CTM struggles with:

- Long, strict sequential dependencies (parity-144)
- Tasks requiring O(n) sequential steps where n > ~64
## License

MIT License (same as the original CTM repository).

## Acknowledgments

- Sakana AI for the Continuous Thought Machine architecture
- The original CTM repository