vincentoh committed on
Commit
63d759d
·
verified ·
1 Parent(s): 0480e79

Upload README.md with huggingface_hub

  1. README.md +212 -159
README.md CHANGED
@@ -1,197 +1,250 @@
1
- # CTM Experiments
2
 
3
- Personal experiments with [Continuous Thought Machines](https://github.com/SakanaAI/continuous-thought-machines) (SakanaAI).
4
 
5
- **Interactive Demo**: https://pub.sakana.ai/ctm/
6
 
7
- ## Core Insight: Thinking Takes Time
8
 
9
- CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers.
10
 
11
- This enables CTM to learn algorithmic reasoning that feedforward networks struggle with:
12
-
13
- | Task | Challenge | What CTM Learns |
14
- |------|-----------|-----------------|
15
- | **Parity** | Count bits across sequence | Iterative accumulation |
16
- | **Brackets** | Track nested structure | Stack-like memory (LIFO) |
17
- | **Object Tracking** | Extrapolate motion | Physics simulation |
18
- | **Mazes** | Navigate 2D paths | Sequential decision making |
19
- | **Jigsaw** | Classify shuffled patches | Part-whole integration |
20
-
21
- ## Results Summary
22
-
23
- | Experiment | Accuracy | Notes |
24
- |------------|----------|-------|
25
- | **MNIST** | **97.9%** | Digit classification, 5 min training |
26
- | **Parity-16** | **99.0%** | 16-bit cumulative parity |
27
- | **QAMNIST** | **100%** | Multi-step arithmetic (3-5 digits, 3-5 ops) |
28
- | **Brackets** | **94.7%** | Stack-like reasoning for `(()[])` vs `([)]` |
29
- | **Object Tracking** | **100%** | Quadrant prediction from motion (4 classes) |
30
- | **Velocity Prediction** | **100%** | Direction prediction (9 classes) |
31
- | **Position Prediction** | **93.8%** | Exact position (256 classes, 16x16 grid) |
32
- | **Transfer Learning** | **94.5%** | Parity→Brackets (core frozen) |
33
- | **Maze Solving** | **Visualized** | Pretrained model inference on 15x15 mazes |
34
- | **Jigsaw MNIST** | **92%** | Classify digits from shuffled patches (no positional encoding) |
35
-
36
- ## Key Findings
37
-
38
- ### 1. Architecture Matters More Than Scale
39
-
40
- Early experiments showed 50% accuracy on parity (random guessing). The fix wasn't more parameters - it was using the **correct architecture**:
41
-
42
- | Parameter | Wrong | Correct (Official) |
43
- |-----------|-------|-------------------|
44
- | `n_synch_out` | 512 | **32** |
45
- | `n_synch_action` | 512 | **32** |
46
- | `synapse_depth` | 4 (U-NET) | **1** (linear) |
47
-
48
- The official parity implementation uses surprisingly small synchronization dimensions with a linear synapse - this is critical for learning.
49
-
50
- ### 2. "Thinking Longer" = Higher Accuracy
51
-
52
- ![MNIST Inference per Tick](continuous-thought-machines/experiments/results/mnist_inference.png)
53
-
54
- CTM accuracy improves with more internal iterations:
55
- - **Tick 0**: 7% (random)
56
- - **Tick 10-11**: 100% (peak)
57
- - **Final tick**: 98%
58
-
59
- Harder tasks need more "thinking time" - parity peaks at tick 35.
60
-
61
- ### 3. Transfer Learning Works
62
-
63
- Pretrained parity model transfers to brackets:
64
- - **Baseline**: 52.5% (random)
65
- - **After transfer**: 94.5% (core frozen, only backbone/output trained)
66
-
67
- The iterative counting learned for parity transfers to stack tracking for brackets - matching from-scratch performance with only 37.7% of parameters trainable.
68
-
69
- ### 4. Maze Solving "The Hard Way"
70
-
71
- CTM solves mazes by outputting action trajectories (Up/Down/Left/Right/Wait), not pixel masks:
72
- - **Step accuracy**: 60%+ after 2000 iterations
73
- - Uses auto-extending curriculum (loss only on trajectory up to first error)
74
- - Demonstrates sequential reasoning capability
75
-
76
- ![Maze Attention Overlay](continuous-thought-machines/experiments/results/maze_attention.gif)
77
-
78
- *CTM "thinking" through a 15x15 maze: blue = predicted path, red = attention focus, green = start position. The attention heatmap shows where CTM looks at each internal tick (T=75 iterations).*
79
-
80
- ## Detailed Results
81
-
82
- ### MNIST Digit Classification (97.9%)
83
-
84
- ![MNIST Training Accuracy](continuous-thought-machines/experiments/results/mnist-ctm_smoothed.png)
85
-
86
- CTM learns digit classification in ~5 minutes on RTX 4070 Ti.
87
-
88
- ### Parity-16 Cumulative Parity (99.0%)
89
-
90
- ![Parity Inference per Tick](continuous-thought-machines/experiments/results/parity_inference.png)
91
 
92
- 16-bit parity with cumulative outputs - harder task shows clearer "thinking" benefit.
93
 
94
- ### QAMNIST Multi-Step Arithmetic (100%)
95
 
96
- ![QAMNIST Training Accuracy](continuous-thought-machines/experiments/results/qamnist-ctm-10_smoothed.png)
97
 
98
- 100% accuracy on multi-step arithmetic (3-5 MNIST digits, 3-5 operations) after 300k iterations.
99
 
100
- ### Maze Navigation (Pretrained Model)
101
 
102
- Using the authors' pretrained checkpoint (`ctm_mazeslarge_D=2048_T=75_M=25.pt`), we ran inference on the small-mazes dataset:
103
 
104
- - **Model**: D=2048 neurons, T=75 thinking steps, M=25 max trajectory length
105
- - **Dataset**: 1000 test mazes (15x15 grid)
106
- - **Output**: Action trajectories (Up/Down/Left/Right/Wait)
107
 
108
- The visualization shows CTM's attention patterns as it navigates:
109
- 1. **Red heatmap**: Where CTM "looks" at each thinking step
110
- 2. **Blue path**: Predicted solution trajectory
111
- 3. **Green marker**: Start position
112
 
113
- Key insight: CTM learns sequential decision-making through iterative internal computation, not memorization.
114
 
115
- ### Object Tracking - Position Prediction (93.8%)
 
 
116
 
117
- ![Position Tracking Training](continuous-thought-machines/experiments/results/tracking_position.png)
 
 
 
 
118
 
119
- The hardest tracking task: predict exact cell (256 classes) from 5 frames of motion. CTM reaches 93.8% test accuracy, demonstrating temporal reasoning across video frames.
 
120
 
121
- ## Experiment Tracking
 
122
 
123
- - **Configs**: [`experiments/experiments.json`](continuous-thought-machines/experiments/experiments.json)
124
- - **Training Scripts**: [`experiments/training/`](continuous-thought-machines/experiments/training/)
125
- - **Inference Scripts**: [`experiments/inference/`](continuous-thought-machines/experiments/inference/)
126
- - **Results**: [`experiments/results/`](continuous-thought-machines/experiments/results/)
127
 
128
- ## Custom Experiments
 
 
 
129
 
130
- ### Bracket Matching
131
- Classify bracket strings as valid or invalid: `(()[])` vs `([)]`
132
 
133
- Requires tracking nested depth and bracket types - implementing a stack through iterative thinking.
 
 
 
134
 
135
- ### Object Tracking
136
- Predict properties of a moving dot from 5 video frames (16x16 grid).
137
 
138
- ```
139
- Frame 0 Frame 1 Frame 2 Frame 3 Frame 4
140
- . . . . . . . . . . . . . . . . . . . .
141
- . * . . . . * . . . . * . . . . . . . .
142
- . . . . . . . . . . . . . . . * . . . .
143
- . . . . . . . . . . . . . . . . . . . *
144
- ```
145
 
146
- Three prediction tasks tested:
147
- | Task | Classes | Accuracy | Notes |
148
- |------|---------|----------|-------|
149
- | **Quadrant** | 4 | 100% | TL/TR/BL/BR - easiest |
150
- | **Velocity** | 9 | 100% | 8 directions + stationary |
151
- | **Position** | 256 | 93.8% | Exact cell (16x16) - hardest |
152
 
153
- All tasks converged, demonstrating CTM's ability to learn temporal/spatial reasoning.
154
 
155
- ### Transfer Learning
156
- Freeze core CTM dynamics from parity-16, train only backbone/output for brackets.
157
 
158
- ### Maze Inference
159
- Run pretrained maze model on small-mazes dataset to visualize CTM's "thinking" process:
160
 
161
- ```bash
162
- python -m tasks.mazes.analysis.run \
163
- --actions viz \
164
- --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt \
165
- --dataset_for_viz small-mazes
166
- ```
167
 
168
- Outputs attention overlay GIFs to `tasks/mazes/analysis/outputs/viz/`.
169
 
170
- ### Jigsaw MNIST
171
- Classify MNIST digits from **randomly shuffled patches** without positional encoding.
 
 
172
 
173
- ```
174
- Original: Shuffled (input):
175
- ┌───┬───┬───┬───┐ ┌───┬───┬───┬───┐
176
- │ 1 │ 2 │ 3 │ 4 │ │12 │ 7 │ 2 │15 │
177
- ├───┼───┼───┼───┤ ├───┼───┼───┼───┤
178
- │ 5 │ 6 │ 7 │ 8 │ => │ 4 │11 │ 9 │ 1 │
179
- ├───┼───┼───┼───┤ ├───┼───┼───┼───┤
180
- │ 9 │10 │11 │12 │ │ 6 │ 3 │14 │ 5 │
181
- ├───┼───┼───┼───┤ ├───┼───┼───┼───┤
182
- │13 │14 │15 │16 │ │16 │ 8 │10 │13 │
183
- └───┴───┴───┴───┘ └───┴───┴───┴───┘
184
- ```
185
 
186
- **Task**: Given 16 shuffled 7x7 patches, predict the digit class (0-9).
187
 
188
- **Challenge**: No positional encoding - CTM must learn to recognize digit parts and integrate them correctly through its internal synchronization dynamics.
189
 
190
- **Result**: **92% test accuracy** - CTM successfully learns part-whole relationships without explicit position information.
191
 
192
- ![Jigsaw Training](continuous-thought-machines/experiments/results/jigsaw_training.png)
 
193
 
194
- ## Resources
195
 
196
- - [CTM Paper](2505.05522v4.pdf)
197
- - [Original SakanaAI Repo](https://github.com/SakanaAI/continuous-thought-machines)
 
 
1
+ # CTM Experiments - Continuous Thought Machine Models
2
 
3
+ Experimental checkpoints trained on the [Continuous Thought Machine](https://github.com/SakanaAI/continuous-thought-machines) architecture by Sakana AI.
4
 
5
+ **These are community experiments on the original work - not official SakanaAI models.**
6
 
7
+ ## Paper Reference
8
 
9
+ > **Continuous Thought Machines**
10
+ >
11
+ > Sakana AI
12
+ >
13
+ > [arXiv:2505.05522](https://arxiv.org/abs/2505.05522)
14
+ >
15
+ > [Interactive Demo](https://pub.sakana.ai/ctm/) | [Blog Post](https://sakana.ai/ctm/)
16
 
17
+ ```bibtex
18
+ @article{sakana2025ctm,
19
+ title={Continuous Thought Machines},
20
+ author={Sakana AI},
21
+ journal={arXiv preprint arXiv:2505.05522},
22
+ year={2025}
23
+ }
24
+ ```
25
 
26
+ ## Core Insight
27
+
28
+ CTM's key innovation: **accuracy improves with more internal iterations**. The model "thinks longer" to reach better answers. This enables CTM to learn algorithmic reasoning that feedforward networks struggle with.
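The "thinks longer" claim can be checked by scoring the prediction at every internal tick separately. The following is a minimal sketch, assuming the per-tick class scores have already been collected into a nested list indexed as `logits[sample][tick][class]`. That layout and the helper name `accuracy_per_tick` are illustrative assumptions, not part of the CTM codebase, and the numbers are toy values rather than real model output.

```python
def accuracy_per_tick(logits, labels):
    """Return one accuracy value per internal tick.

    logits: nested list indexed as logits[sample][tick][class] (assumed layout).
    labels: list of integer class labels, one per sample.
    """
    n_ticks = len(logits[0])
    correct = [0] * n_ticks
    for sample_logits, label in zip(logits, labels):
        for t, class_scores in enumerate(sample_logits):
            # argmax over classes at this tick
            predicted = max(range(len(class_scores)), key=class_scores.__getitem__)
            correct[t] += int(predicted == label)
    return [c / len(labels) for c in correct]


# Two samples, three ticks, two classes (toy numbers, not real model output).
toy_logits = [
    [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]],  # label 1: wrong, right, right
    [[0.3, 0.7], [0.6, 0.4], [0.9, 0.1]],  # label 0: wrong, right, right
]
toy_labels = [1, 0]
print(accuracy_per_tick(toy_logits, toy_labels))  # [0.0, 1.0, 1.0]
```

Plotting the returned list against the tick index gives the kind of accuracy-per-tick curve this claim refers to.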
29
+
30
+ ## Models
31
+
32
+ | Model | File | Size | Task | Accuracy | Description |
33
+ |-------|------|------|------|----------|-------------|
34
+ | MNIST | `ctm-mnist.pt` | 1.3M | Digit classification | 97.9% | 10-class MNIST |
35
+ | Parity-16 | `ctm-parity-16.pt` | 2.5M | Cumulative parity | 99.0% | 16-bit sequences |
36
+ | Parity-64 | `ctm-parity-64.pt` | 66M | Cumulative parity | 58.6% | 64-bit sequences (custom config) |
37
+ | Parity-64 Official | `ctm-parity-64-official.pt` | 21M | Cumulative parity | 57.7% | 64-bit sequences (official config) |
38
+ | QAMNIST | `ctm-qamnist.pt` | 39M | Multi-step arithmetic | 100% | 3-5 digits, 3-5 ops |
39
+ | Brackets | `ctm-brackets.pt` | 6.1M | Bracket matching | 94.7% | Valid/invalid `(()[])` |
40
+ | Tracking-Quadrant | `ctm-tracking-quadrant.pt` | 6.7M | Motion quadrant | 100% | 4-class prediction |
41
+ | Tracking-Position | `ctm-tracking-position.pt` | 6.7M | Exact position | 93.8% | 256-class (16x16 grid) |
42
+ | Transfer | `ctm-transfer-parity-brackets.pt` | 2.5M | Transfer learning | 94.5% | Parity core to brackets |
43
+ | Jigsaw MNIST | `ctm-jigsaw-mnist.pt` | 19M | Jigsaw puzzle solving | 92.3% | Reassemble 2x2 shuffled MNIST |
44
+ | Rotation MNIST | `ctm-rotation-mnist.pt` | 4.2M | Rotation prediction | 89.1% | Predict rotation angle (4 classes) |
45
+ | Brackets Transfer | `ctm-brackets-transfer-depth4.pt` | 6.1M | Transfer learning | 95.1% | Parity→Brackets (depth 4 synapse) |
46
+ | Dual-Task | `ctm-dual-task-brackets-parity.pt` | 2.8M | Multi-task | 86.1% | Brackets (94%) + Parity (78%) jointly |
47
+ | Parity-64 | `ctm-parity-64-8x8.pt` | 4.1M | Long parity | 58.6% | 64-bit (8x8) cumulative parity |
48
+ | Parity-144 | `ctm-parity-144-12x12.pt` | 4.1M | Long parity | 51.7% | 144-bit (12x12) cumulative parity |
49
+
50
+ ## Model Configurations
51
+
52
+ ### MNIST CTM
53
+ ```python
54
+ config = {
55
+ "iterations": 15,
56
+ "memory_length": 10,
57
+ "d_model": 128,
58
+ "d_input": 128,
59
+ "heads": 2,
60
+ "n_synch_out": 16,
61
+ "n_synch_action": 16,
62
+ "memory_hidden_dims": 8,
63
+ "out_dims": 10,
64
+ "synapse_depth": 1,
65
+ }
66
+ ```
67
 
68
+ ### Parity-16 CTM
69
+ ```python
70
+ config = {
71
+ "iterations": 50,
72
+ "memory_length": 25,
73
+ "d_model": 256,
74
+ "d_input": 32,
75
+ "heads": 8,
76
+ "synapse_depth": 8,
77
+ "out_dims": 16, # cumulative parity
78
+ }
79
+ ```
80
 
81
+ ### Parity-64 Official CTM
82
+ ```python
83
+ config = {
84
+ "iterations": 75,
85
+ "memory_length": 25,
86
+ "d_model": 1024,
87
+ "d_input": 64,
88
+ "heads": 8,
89
+ "n_synch_out": 32,
90
+ "n_synch_action": 32,
91
+ "synapse_depth": 1, # linear synapse (official)
92
+ "out_dims": 64, # cumulative parity
93
+ }
94
+ ```
95
 
96
+ ### QAMNIST CTM
97
+ ```python
98
+ config = {
99
+ "iterations": 10,
100
+ "memory_length": 30,
101
+ "d_model": 1024,
102
+ "d_input": 64,
103
+ "synapse_depth": 1,
104
+ "heads": 4,
105
+ "n_synch_out": 32,
106
+ "n_synch_action": 32,
107
+ }
108
+ ```
109
 
110
+ ### Brackets CTM
111
+ ```python
112
+ config = {
113
+ "iterations": 30,
114
+ "memory_length": 15,
115
+ "d_model": 256,
116
+ "d_input": 64,
117
+ "heads": 4,
118
+ "n_synch_out": 32,
119
+ "n_synch_action": 32,
120
+ "out_dims": 2, # valid/invalid
121
+ }
122
+ ```
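The two output classes correspond to valid and invalid bracket strings. As a point of reference, ground-truth labels for such strings can be produced with an ordinary stack. This is a minimal sketch assuming only round and square brackets, as in the `(()[])` example; the helper name `is_balanced` is hypothetical and this is not the dataset code used for these checkpoints.

```python
# Map each closing bracket to the opening bracket it must match.
PAIRS = {")": "(", "]": "["}


def is_balanced(s):
    """Return True if every bracket in s is closed in the correct (LIFO) order."""
    stack = []
    for ch in s:
        if ch in PAIRS.values():
            stack.append(ch)  # opening bracket: remember it
        elif ch in PAIRS:
            # closing bracket: must match the most recent unclosed opener
            if not stack or stack.pop() != PAIRS[ch]:
                return False
    return not stack  # leftover openers mean the string is invalid


print(is_balanced("(()[])"))  # True
print(is_balanced("([)]"))    # False
```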
123
 
124
+ ### Tracking CTM
125
+ ```python
126
+ config = {
127
+ "iterations": 20,
128
+ "memory_length": 15,
129
+ "d_model": 256,
130
+ "d_input": 64,
131
+ "heads": 4,
132
+ "n_synch_out": 32,
133
+ "n_synch_action": 32,
134
+ }
135
+ ```
136
 
137
+ ### Jigsaw MNIST CTM
138
+ ```python
139
+ config = {
140
+ "iterations": 30,
141
+ "memory_length": 20,
142
+ "d_model": 512,
143
+ "d_input": 128,
144
+ "heads": 8,
145
+ "n_synch_out": 32,
146
+ "n_synch_action": 32,
147
+ "synapse_depth": 1,
148
+ "out_dims": 24, # 4 tiles x 6 permutation options
149
+ "backbone_type": "jigsaw",
150
+ }
151
+ ```
152
 
153
+ ### Rotation MNIST CTM
154
+ ```python
155
+ config = {
156
+ "iterations": 20,
157
+ "memory_length": 15,
158
+ "d_model": 256,
159
+ "d_input": 64,
160
+ "heads": 4,
161
+ "n_synch_out": 32,
162
+ "n_synch_action": 32,
163
+ "synapse_depth": 1,
164
+ "out_dims": 4, # 0°, 90°, 180°, 270°
165
+ "backbone_type": "rotation",
166
+ }
167
+ ```
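The four output classes correspond to quarter-turn rotations. One way such labelled examples could be generated is sketched below in plain Python. The clockwise direction, the label convention (label `k` means `k * 90` degrees), and the helper names `rotate90` and `make_example` are assumptions for illustration, not the preprocessing used for this checkpoint.

```python
def rotate90(grid, k):
    """Rotate a 2D list clockwise by k quarter turns."""
    for _ in range(k % 4):
        # Reverse the row order, then transpose: one clockwise quarter turn.
        grid = [list(row) for row in zip(*grid[::-1])]
    return grid


def make_example(image, k):
    """Return (rotated_image, label), where label k stands for k * 90 degrees."""
    return rotate90(image, k), k % 4


image = [[1, 2],
         [3, 4]]
rotated, label = make_example(image, 1)
print(rotated, label)  # [[3, 1], [4, 2]] 1
```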
168
 
169
+ ## Usage
170
 
171
+ ```python
172
+ import torch
173
+ from huggingface_hub import hf_hub_download
174
 
175
+ # Download model
176
+ model_path = hf_hub_download(
177
+ repo_id="vincentoh/ctm-experiments",
178
+ filename="ctm-mnist.pt"
179
+ )
180
 
181
+ # Load checkpoint
182
+ checkpoint = torch.load(model_path, map_location="cpu")
183
 
184
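+ # Initialize CTM with the matching `config` dict from "Model Configurations"
+ # (requires the SakanaAI continuous-thought-machines repo on the Python path)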
185
+ from models.ctm import ContinuousThoughtMachine
186
 
187
+ model = ContinuousThoughtMachine(**config)
188
+ model.load_state_dict(checkpoint['model_state_dict'])
189
+ model.eval()
 
190
 
191
+ # Inference
192
+ with torch.no_grad():
193
+     output = model(input_tensor)
194
+ ```
195
 
196
+ ## Training Details
 
197
 
198
+ - **Hardware**: NVIDIA RTX 4070 Ti SUPER
199
+ - **Framework**: PyTorch
200
+ - **Optimizer**: AdamW
201
+ - **Training time**: 5 minutes (MNIST) to 17 hours (QAMNIST)
202
 
203
+ ## Key Findings
 
204
 
205
+ 1. **Architecture > Scale**: Small sync dimensions (32) with linear synapses work better than large/deep variants
206
+ 2. **"Thinking Longer" = Higher Accuracy**: CTM accuracy improves with more internal iterations
207
+ 3. **Transfer Learning Works**: Parity-trained core transfers to brackets with 94.5% accuracy
208
+ 4. **Architectural Limits**: CTM plateaus at ~58% on 64-bit parity under both configurations tested (custom and official)
209
 
210
+ ## Parity Scaling Experiments
211
 
212
+ We tested CTM on increasingly long parity sequences to find where it breaks down:
213
 
214
+ | Sequence | Grid | Accuracy | vs Random | Status |
215
+ |----------|------|----------|-----------|--------|
216
+ | 16 | 4x4 | **99.0%** | +49.0% | ✅ Solved |
217
+ | 36 | 6x6 | **66.3%** | +16.3% | ⚠️ Degraded |
218
+ | 64 | 8x8 | **58.6%** | +8.6% | ❌ Struggling |
219
+ | 64 (official) | 8x8 | **57.7%** | +7.7% | ❌ Same ceiling |
220
+ | 144 | 12x12 | **51.7%** | +1.7% | ❌ Random |
221
 
222
+ **Key insight**: Both the custom config (d_model=512, synapse_depth=4) and the official config (d_model=1024, synapse_depth=1) reach essentially the same accuracy on parity-64, which suggests the ~58% ceiling is an **architectural limit** rather than a hyperparameter issue.
 
223
 
224
+ ### Why CTM Fails on Long Parity
225
 
226
+ Parity requires **strict sequential computation**: process bit 1 before bit 2 before bit 3, and so on. CTM's attention-based "thinking" is fundamentally parallel - all positions attend simultaneously. The model learns the sequential pattern for short sequences (99.0% at 16 bits), but accuracy degrades at 36 bits and falls to near chance by 144 bits.
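To make the sequential dependency concrete, cumulative parity can be written as a running XOR, where each output depends on every bit before it. This is a minimal sketch assuming bits are encoded as 0/1 integers; the encoding used by the actual training data may differ, and the helper name `cumulative_parity` is illustrative.

```python
from itertools import accumulate
from operator import xor


def cumulative_parity(bits):
    """Return the running parity: out[i] is the XOR of bits[0] through bits[i]."""
    return list(accumulate(bits, xor))


bits = [1, 0, 1, 1]
print(cumulative_parity(bits))  # [1, 1, 0, 1]
```

Flipping a single early bit flips every later target, which is why the task cannot be solved from local context alone.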
227
 
228
+ **CTM excels at:**
229
+ - Moderate sequence lengths (< 64 elements)
230
+ - Local dependencies (brackets: track depth, not full history)
231
+ - Parallelizable structure (MNIST: patches contribute independently)
232
 
233
+ **CTM struggles with:**
234
+ - Long strict sequential dependencies (parity-144)
235
+ - Tasks requiring O(n) sequential steps where n > ~64
236
 
237
+ ## License
238
 
239
+ MIT License (same as original CTM repository)
240
 
241
+ ## Acknowledgments
242
 
243
+ - [Sakana AI](https://sakana.ai/) for the Continuous Thought Machine architecture
244
+ - Original [CTM Repository](https://github.com/SakanaAI/continuous-thought-machines)
245
 
246
+ ## Links
247
 
248
+ - [Experiment Repository](https://github.com/bigsnarfdude/ctm-experiments)
249
+ - [Original Paper](https://arxiv.org/abs/2505.05522)
250
+ - [Interactive Demo](https://pub.sakana.ai/ctm/)