LLaMA4-Tiny-VLM

A 470M-parameter vision-language model built entirely from scratch, implementing LLaMA 4 architecture innovations.

Model Description

This is an educational implementation of a complete VLM training pipeline, from text pretraining through preference optimization.

Architecture Highlights

| Component | Details |
|---|---|
| LLM | 380M params, 12 layers, 768 hidden dim |
| Vision Encoder | ViT-B/16 (86M params, frozen) |
| Projector | 2-layer MLP (4.7M params) |
| Total | ~470M parameters |
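
The ~4.7M projector count is consistent with a two-layer MLP that maps the ViT's 768-dim features into the LLM's 768-dim embedding space through a wider hidden layer (2 × 768 × 3072 ≈ 4.7M weights). The sketch below is illustrative only; the hidden width and activation are assumptions, not the repo's actual module.

```python
import torch.nn as nn

# Illustrative vision-to-language projector. The 3072 hidden width and GELU
# activation are assumptions chosen to roughly match the stated ~4.7M parameters.
projector = nn.Sequential(
    nn.Linear(768, 3072),   # ViT-B/16 patch features -> hidden
    nn.GELU(),
    nn.Linear(3072, 768),   # hidden -> LLM embedding dimension
)
```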

Key Innovations:

  • Grouped Query Attention (GQA): 12 query heads, 4 KV heads (~3x KV-cache memory savings; see the sketch after this list)
  • iRoPE: Interleaved RoPE/NoPE layers (3:1 pattern) with chunked attention
  • Mixture of Experts: 8 experts, top-2 routing, shared expert
  • SwiGLU: Gated activation in all FFN blocks
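
As a rough illustration of the GQA layout referenced above, the sketch below shares each of the 4 KV heads across a group of 3 query heads. Names and shapes are illustrative, not the repo's actual attention module.

```python
import torch

n_q_heads, n_kv_heads, head_dim, seq = 12, 4, 64, 16   # 12 * 64 = 768 hidden dim
group = n_q_heads // n_kv_heads                         # 3 query heads per KV head

q = torch.randn(1, seq, n_q_heads, head_dim)
k = torch.randn(1, seq, n_kv_heads, head_dim)           # only 4 KV heads are cached
v = torch.randn(1, seq, n_kv_heads, head_dim)           # -> ~3x smaller KV cache

# Expand KV heads so each group of 3 query heads attends to the same KV head
k = k.repeat_interleave(group, dim=2).transpose(1, 2)   # (1, 12, seq, head_dim)
v = v.repeat_interleave(group, dim=2).transpose(1, 2)
q = q.transpose(1, 2)

attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(1, seq, -1)    # (1, seq, 768)
```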

Training Pipeline

| Phase | Dataset | Trainable | Result |
|---|---|---|---|
| 1. Text Pretraining | HuggingFaceTB/smollm-corpus (cosmopedia-v2) | 380M (100%) | LLM base |
| 2. VL Alignment | jxie/coco_captions (567K) | 4.7M (1%) | val_loss: 3.23 |
| 3. Instruction Tuning | liuhaotian/LLaVA-Instruct-150K (142K COCO) | 385M (82%) | val_loss: 1.69 |
| 4. DPO Tuning | HuggingFaceH4/rlaif-v_formatted (79K) | 385M (82%) | val_acc: 64.5% |
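
Phase 4 uses Direct Preference Optimization on chosen/rejected answer pairs. The following is a minimal sketch of the standard DPO objective and the preference accuracy reported as val_acc; it is not the repo's exact training code, and beta=0.1 is an assumed value.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Standard DPO loss over summed log-probs of chosen vs. rejected responses."""
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps)
    loss = -F.logsigmoid(chosen_reward - rejected_reward).mean()
    # val_acc above corresponds to the fraction of pairs the policy ranks correctly
    accuracy = (chosen_reward > rejected_reward).float().mean()
    return loss, accuracy
```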

Checkpoints

| File | Description | Size |
|---|---|---|
| checkpoints/text_pretraining_best.pt | Phase 1: Pretrained LLM | 3.1GB |
| checkpoints/vision_alignment_best.pt | Phase 2: Aligned projector | 1.8GB |
| checkpoints/instruction_tuning_best.pt | Phase 3: Instruction-tuned | 3.5GB |
| checkpoints/dpo_best.pt | Phase 4: DPO-tuned (final) | 3.5GB |

Usage

```bash
# Clone repo and install dependencies
git clone https://github.com/siddhantmedar/llama4-tiny-vlm
cd llama4-tiny-vlm
pip install torch torchvision huggingface_hub tokenizers
```

```python
import torch
import tomllib
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from tokenizers import Tokenizer

# Download checkpoints (cached automatically)
llm_ckpt = hf_hub_download("medarsiddhant/llama4-tiny-vlm", "checkpoints/text_pretraining_best.pt")
vlm_ckpt = hf_hub_download("medarsiddhant/llama4-tiny-vlm", "checkpoints/vision_alignment_best.pt")
dpo_ckpt = hf_hub_download("medarsiddhant/llama4-tiny-vlm", "checkpoints/dpo_best.pt")

# Load model
from visual_instruction_tuning.model import create_instruct_vlm

with open("config.toml", "rb") as f:
    config = tomllib.load(f)

model = create_instruct_vlm(config, llm_ckpt, vlm_ckpt)
ckpt = torch.load(dpo_ckpt, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
model = model.to("cuda").eval()

# Load tokenizer
tokenizer = Tokenizer.from_file("visual_instruction_tuning/bpe_tokenizer_with_image_tag.json")
IMAGE_TOKEN_ID = tokenizer.token_to_id("<image>")
EOS_TOKEN_ID = tokenizer.token_to_id("</s>")

# Image transform (standard ImageNet preprocessing for the ViT-B/16 encoder)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Inference function: greedy decoding, re-encoding the full sequence each step (no KV cache)
@torch.no_grad()
def ask(image_path, question, max_tokens=100):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to("cuda")

    prompt = f" USER: {question} ASSISTANT:"
    input_ids = [IMAGE_TOKEN_ID] + tokenizer.encode(prompt).ids
    input_ids = torch.tensor([input_ids], device="cuda")

    generated = []
    for _ in range(max_tokens):
        logits = model(image_tensor, input_ids)
        next_token = logits[0, -1, :].argmax().item()
        if next_token == EOS_TOKEN_ID:
            break
        generated.append(next_token)
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device="cuda")], dim=1)

    return tokenizer.decode(generated)

# Example usage
response = ask("path/to/image.jpg", "What is in this image?")
print(response)
```

Limitations

  • Small model size: 470M params limits reasoning capacity compared to 7B+ models
  • Repetition: May produce repetitive outputs; applying a repetition penalty of about 1.2 at decode time helps (see the sketch after this list)
  • Training data: Limited to COCO images and synthetic captions
  • Educational purpose: Not intended for production use
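
A minimal way to apply the suggested repetition penalty inside the greedy loop from the Usage section is sketched below; penalize_repeats is a hypothetical helper, not part of the repo.

```python
import torch

def penalize_repeats(logits, generated_ids, penalty=1.2):
    """Down-weight tokens that were already generated (hypothetical helper)."""
    for tok in set(generated_ids):
        # Dividing positive logits and multiplying negative ones both lower the token's probability
        logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty
    return logits

# Inside ask(), before taking the argmax:
#   next_logits = penalize_repeats(logits[0, -1, :], generated)
#   next_token = next_logits.argmax().item()
```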

Training Hardware

  • GPU: NVIDIA RTX 3090 (24GB)
  • Total training time: ~35-40 hours across all phases

Links

  • GitHub: https://github.com/siddhantmedar/llama4-tiny-vlm
  • Model: https://huggingface.co/medarsiddhant/llama4-tiny-vlm

Citation

@misc{llama4-tiny-vlm,
  author = {Siddhant Medar},
  title = {LLaMA4-Tiny-VLM: A Vision-Language Model from Scratch},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/medarsiddhant/llama4-tiny-vlm}
}

License

MIT
