Eliciting Latent Predictions from Transformers with the Tuned Lens
Paper: [arXiv:2303.08112](https://arxiv.org/abs/2303.08112)
A trained Tuned Lens for Llamba-1B, a Mamba-based SSM language model.

The Tuned Lens (Belrose et al., 2023) trains an affine probe ("translator") at each layer that maps intermediate hidden states to vocabulary-space logits. This makes it possible to visualize how the model's prediction evolves layer by layer during the forward pass.

For Mamba/SSM models, this reveals how recurrent state refinement progressively builds the final prediction: early layers produce high-entropy (uncertain) distributions, while later layers converge to the final output.
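Concretely, writing $h_\ell$ for the hidden state after layer $\ell$, $\mathrm{Norm}$ for the model's final (RMS) normalization, and $W_U$ for the unembedding matrix, the lens prediction at layer $\ell$ is

$$
\mathrm{TunedLens}_\ell(h_\ell) = \mathrm{Norm}\!\left(A_\ell h_\ell + b_\ell\right) W_U^{\top},
$$

where only the affine translators $(A_\ell, b_\ell)$ are trained; $\mathrm{Norm}$ and $W_U$ are taken frozen from the base model. This is the formulation of Belrose et al. (2023) and matches the code below.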
Model and training configuration:

| Parameter | Value |
|---|---|
| Base Model | cartesia-ai/Llamba-1B |
| Architecture | Mamba (SSM) |
| Layers | 16 |
| Hidden Dim | 2048 |
| Vocab Size | 128,256 |
| Training Data | WikiText-2 (1644 samples) |
| Epochs | 1 |
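The lens checkpoint ships with a `config.json` holding the fields the loading code below reads (`d_model`, `vocab_size`, `num_layers`, `bias`). Given the table above, the dict returned by `json.load` would look roughly like the sketch below; the `bias` value is an assumption, not stated in this card.

```python
# Rough shape of the loaded lens config; the bias value is an assumption.
config = {
    "d_model": 2048,       # hidden dimension
    "vocab_size": 128256,  # Llama-3 vocabulary size
    "num_layers": 16,      # number of layers probed
    "bias": True,          # whether translators include a bias term (assumed)
}
```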
Training results:

| Metric | Value |
|---|---|
| Initial Loss | ~3.84 |
| Final Loss | ~1.43 |
| Layer 12 Loss | 0.775 |
| Layer 13 Loss | 0.714 |
| Layer 14 Loss | 0.443 |
| Layer 15 Loss | 0.203 |
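For reference, here is a minimal sketch of how such per-layer losses are typically obtained: the base model stays frozen and only the translators are optimized. The objective, optimizer, and learning rate below are assumptions (the standard tuned-lens KL objective from Belrose et al., 2023), not details taken from this card; `model` and `lens` are the objects defined in the usage example below.

```python
import torch
import torch.nn.functional as F

# Assumed setup: `model` (frozen Llamba-1B) and `lens` (MambaTunedLens) as in the
# usage example below; optimizer and lr are illustrative, not the values used here.
optimizer = torch.optim.Adam(lens.translators.parameters(), lr=1e-3)

def lens_training_step(input_ids):
    """One step: match each layer's lens distribution to the model's final distribution."""
    with torch.no_grad():
        outputs = model(input_ids, return_hidden_states=True)
        # Assumes the model output exposes final-layer logits as `outputs.logits`
        final_log_probs = F.log_softmax(outputs.logits.float(), dim=-1)

    loss = 0.0
    for layer_idx in range(lens.num_layers):
        # Cast to the translator dtype in case model and lens dtypes differ
        hidden = outputs.all_hidden_states[layer_idx + 1].to(lens.translators[layer_idx].weight.dtype)
        lens_log_probs = F.log_softmax(lens.forward_layer(hidden, layer_idx).float(), dim=-1)
        # KL(final || lens), summed over layers
        loss = loss + F.kl_div(lens_log_probs, final_log_probs,
                               log_target=True, reduction="batchmean")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```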
Usage:

```python
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from cartesia_pytorch.Llamba import LlambaLMHeadModel

# Download the trained lens weights and config from the Hub
lens_path = hf_hub_download("Xeiroh/llamba-1b-tuned-lens", "lens.pt")
config_path = hf_hub_download("Xeiroh/llamba-1b-tuned-lens", "config.json")

# Load the lens config
with open(config_path) as f:
    config = json.load(f)

# Load the base model
model = LlambaLMHeadModel.from_pretrained(
    "cartesia-ai/Llamba-1B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).cuda().eval()

# Create lens module
class MambaTunedLens(nn.Module):
    """Per-layer affine translators followed by the model's final RMSNorm and unembedding."""

    def __init__(self, d_model, vocab_size, num_layers, bias=True,
                 unembed_weight=None, final_norm_weight=None, final_norm_eps=1e-5):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        # One affine probe ("translator") per layer
        self.translators = nn.ModuleList([
            nn.Linear(d_model, d_model, bias=bias) for _ in range(num_layers)
        ])
        # Frozen copies of the model's unembedding and final-norm weights
        if unembed_weight is not None:
            self.register_buffer("unembed", unembed_weight.clone())
        if final_norm_weight is not None:
            self.register_buffer("final_norm_weight", final_norm_weight.clone())
        self.final_norm_eps = final_norm_eps

    def forward_layer(self, hidden_state, layer_idx):
        # Affine translation -> RMSNorm -> projection to vocabulary logits
        h = self.translators[layer_idx](hidden_state)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.final_norm_eps)
        h = h * self.final_norm_weight
        return h @ self.unembed.T

# Initialize the lens with the model's unembedding and final-norm weights
unembed_weight = model.lm_head.weight.data
final_norm_weight = model.backbone.final_layernorm.weight.data
final_norm_eps = model.backbone.final_layernorm.variance_epsilon
lens = MambaTunedLens(
    d_model=config["d_model"],
    vocab_size=config["vocab_size"],
    num_layers=config["num_layers"],
    bias=config["bias"],
    unembed_weight=unembed_weight,
    final_norm_weight=final_norm_weight,
    final_norm_eps=final_norm_eps,
).cuda()

# Load the trained translator weights
lens.load_state_dict(torch.load(lens_path, weights_only=True))
# Match the model's dtype so bfloat16 hidden states can be fed directly
lens = lens.to(torch.bfloat16)
lens.eval()

# Use the lens (Llamba uses the Llama-3 tokenizer)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(inputs.input_ids, return_hidden_states=True)

# Decode the top prediction at the last position for every layer
for layer_idx in range(config["num_layers"]):
    hidden = outputs.all_hidden_states[layer_idx + 1]
    logits = lens.forward_layer(hidden, layer_idx)
    pred_token = logits[0, -1].argmax().item()
    print(f"Layer {layer_idx}: {tokenizer.decode([pred_token])}")
```
The repository includes example visualizations.
Citation:

```bibtex
@article{belrose2023eliciting,
  title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
  author={Belrose, Nora and Furman, Zach and Smith, Logan and Halawi, Danny and Ostrovsky, Igor and McKinney, Lev and Biderman, Stella and Steinhardt, Jacob},
  journal={arXiv preprint arXiv:2303.08112},
  year={2023}
}
```
License: Apache 2.0