Llamba-1B Tuned Lens

A trained Tuned Lens for the Llamba-1B model, a Mamba-based SSM language model.

What is the Tuned Lens?

The Tuned Lens (Belrose et al., 2023) trains affine probes at each layer to project intermediate hidden states to vocabulary space. This enables visualization of how predictions evolve layer-by-layer during the forward pass.

For Mamba/SSM models, the lens reveals how successive recurrent layers progressively refine the final prediction: early layers show high entropy (uncertainty), while later layers converge to the final output.
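
Concretely, this lens implements the probe for layer ℓ as a full affine translator followed by the base model's own final RMSNorm and unembedding (this mirrors forward_layer in the Usage section below; the notation here is only illustrative):

$$\text{logits}_\ell = \mathrm{RMSNorm}\!\left(A_\ell h_\ell + b_\ell\right) W_U^{\top}$$

where h_ℓ is the hidden state after layer ℓ, A_ℓ and b_ℓ are the trained translator parameters for that layer, the RMSNorm reuses the model's final-norm weight, and W_U is the model's unembedding (lm_head) matrix.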

Model Details

  • Base Model: cartesia-ai/Llamba-1B
  • Architecture: Mamba (SSM)
  • Layers: 16
  • Hidden Dim: 2048
  • Vocab Size: 128,256
  • Training Data: WikiText-2 (1644 samples)
  • Training: 1 epoch

Training Metrics

  • Initial Loss: ~3.84
  • Final Loss: ~1.43
  • Layer 12 Loss: 0.775
  • Layer 13 Loss: 0.714
  • Layer 14 Loss: 0.443
  • Layer 15 Loss: 0.203

Usage

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from cartesia_pytorch.Llamba import LlambaLMHeadModel
import json

# Download files
lens_path = hf_hub_download("Xeiroh/llamba-1b-tuned-lens", "lens.pt")
config_path = hf_hub_download("Xeiroh/llamba-1b-tuned-lens", "config.json")

# Load config
with open(config_path) as f:
    config = json.load(f)

# Load base model
model = LlambaLMHeadModel.from_pretrained(
    "cartesia-ai/Llamba-1B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).cuda().eval()

# Create lens module
class MambaTunedLens(nn.Module):
    def __init__(self, d_model, vocab_size, num_layers, bias=True,
                 unembed_weight=None, final_norm_weight=None, final_norm_eps=1e-5):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        
        # One learned affine translator per layer
        self.translators = nn.ModuleList([
            nn.Linear(d_model, d_model, bias=bias) for _ in range(num_layers)
        ])

        # Frozen copies of the model's unembedding and final RMSNorm weight
        if unembed_weight is not None:
            self.register_buffer("unembed", unembed_weight.clone())
        if final_norm_weight is not None:
            self.register_buffer("final_norm_weight", final_norm_weight.clone())
        self.final_norm_eps = final_norm_eps
    
    def forward_layer(self, hidden_state, layer_idx):
        # Affine translator for this layer, then the model's final RMSNorm and unembedding
        h = self.translators[layer_idx](hidden_state)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.final_norm_eps)
        h = h * self.final_norm_weight
        return h @ self.unembed.T

# Initialize lens with model weights
unembed_weight = model.lm_head.weight.data
final_norm_weight = model.backbone.final_layernorm.weight.data
final_norm_eps = model.backbone.final_layernorm.variance_epsilon

lens = MambaTunedLens(
    d_model=config["d_model"],
    vocab_size=config["vocab_size"],
    num_layers=config["num_layers"],
    bias=config["bias"],
    unembed_weight=unembed_weight,
    final_norm_weight=final_norm_weight,
    final_norm_eps=final_norm_eps,
).cuda()

# Load trained weights
lens.load_state_dict(torch.load(lens_path, weights_only=True))
lens.eval()

# Use the lens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(inputs.input_ids, return_hidden_states=True)
    
    # Get predictions from each layer
    # (index 0 of all_hidden_states is assumed to be the embedding output, hence the +1 offset)
    for layer_idx in range(config["num_layers"]):
        hidden = outputs.all_hidden_states[layer_idx + 1]
        logits = lens.forward_layer(hidden, layer_idx)
        pred_token = logits[0, -1].argmax().item()
        print(f"Layer {layer_idx}: {tokenizer.decode([pred_token])}")

Visualizations

The repository includes example visualizations:

  • entropy_heatmap.png - Entropy (uncertainty) across layers and positions
  • steering_comparison.png - Before/after steering trajectory comparison
  • trajectory_heatmap.png - Layer-by-layer prediction agreement with final output
  • convergence_depth.png - The layer at which each position's prediction first matches the final output (a sketch of this computation follows the list)
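
As a rough illustration of the last plot, the convergence depth of a position can be taken as the first layer whose lens prediction already matches the final layer's prediction. The sketch below continues the Usage snippet (same model, lens, config, tokenizer, and inputs) and is not the repository's plotting code:

with torch.no_grad():
    outputs = model(inputs.input_ids, return_hidden_states=True)
    num_layers = config["num_layers"]
    # Per-layer argmax predictions for every position: shape (num_layers, seq_len)
    preds = torch.stack([
        lens.forward_layer(outputs.all_hidden_states[i + 1], i)[0].argmax(-1)
        for i in range(num_layers)
    ])
    agrees = preds == preds[-1]  # does layer i already predict the final layer's token?
    layer_ids = torch.arange(num_layers, device=agrees.device).unsqueeze(1).expand_as(agrees)
    # Earliest agreeing layer per position (disagreeing entries are masked to num_layers)
    depth = torch.where(agrees, layer_ids, torch.full_like(layer_ids, num_layers)).amin(dim=0)
    for pos, d in enumerate(depth.tolist()):
        token = tokenizer.decode([inputs.input_ids[0, pos].item()])
        print(f"pos {pos:2d} ({token!r}): converges at layer {d}")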

Citation

@article{belrose2023eliciting,
  title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
  author={Belrose, Nora and Furman, Zach and Smith, Logan and Halawi, Danny and Ostrovsky, Igor and McKinney, Lev and Biderman, Stella and Steinhardt, Jacob},
  journal={arXiv preprint arXiv:2303.08112},
  year={2023}
}

License

Apache 2.0
