---
license: mit
language: en
tags:
- gpt2
- causal-lm
- pytorch
- transformer
- pretraining
- sft
- question-answering
- ultra-fineweb
- custom-dataset
model-index:
- name: gpt2-124m-qa
  results:
  - task:
      name: Question Answering
      type: text-generation
    dataset:
      name: Custom QA Dataset (JSONL)
      type: jsonl
    metrics:
    - name: Loss
      type: loss
      value: 0.65
---

<p align="center">
  <a href="https://huggingface.co/shubharthak/gpt2-124m-qa">
    <img alt="Model Size" src="https://img.shields.io/badge/Model%20Size-124M-blue">
  </a>
  <a href="https://huggingface.co/shubharthak/gpt2-124m-qa">
    <img alt="Downloads" src="https://img.shields.io/huggingface/dl-daily/shubharthak/gpt2-124m-qa">
  </a>
  <a href="https://huggingface.co/shubharthak/gpt2-124m-qa">
    <img alt="Likes" src="https://img.shields.io/badge/HuggingFace-Likes-yellow">
  </a>
  <a href="https://huggingface.co/spaces/yuntian-deng/flash-attention">
    <img alt="Flash Attention" src="https://img.shields.io/badge/Flash%20Attention-Enabled-brightgreen">
  </a>
  <a href="https://pytorch.org/">
    <img alt="PyTorch" src="https://img.shields.io/badge/Framework-PyTorch-red">
  </a>
  <a href="https://huggingface.co/docs">
    <img alt="Task" src="https://img.shields.io/badge/Task-QA%20%2F%20CausalLM-purple">
  </a>
</p>

# GPT-2 124M — Pretrained on Ultra-FineWeb Edu + QA SFT

This repository contains two trained checkpoints of a custom **GPT-2 124M** model:

- **Pretrained Model:** `model_09535.pt`
  → Trained *from scratch* on **Ultra-FineWeb Edu (5B token subset)**
- **QA SFT Model:** `qa-sft_best.pt`
  → Fine-tuned using **Supervised Fine-Tuning (SFT)** on a curated **custom Q&A dataset**

The model was implemented with a **from-scratch GPT-2 training pipeline**, *inspired by Andrej Karpathy’s engineering approach*, but trained independently with different datasets and objectives.

---

## 📦 Model Versions

### **1. Pretrained Model (`model_09535.pt`)**

| Feature | Details |
|---------|---------|
| Parameters | 124M |
| Layers | 12 |
| Heads | 12 |
| Hidden size | 768 |
| Sequence length | 1024 |
| Vocab size | 50304 |
| Dataset | Ultra-FineWeb Edu (educational, high-quality web text) |
| Purpose | General language modeling |

**Goal:** Build a clean GPT-2 Small from scratch to understand and implement a full LLM training pipeline.

---

### **2. QA SFT Model (`qa-sft_best.pt`)**

| Feature | Details |
|---------|---------|
| Base | The pretrained model above |
| Method | Supervised Fine-Tuning (SFT) |
| Dataset | Custom JSONL Q&A dataset |
| Domain | Australian facts, general knowledge, definitions, reasoning |
| Use case | QA-style interactive chatbot |

Demo available at:
👉 **https://gpt2.devshubh.me**

---

# 🧠 Model Architecture

This model follows the **GPT-2 Small** architecture (a configuration sketch follows this list):

- Decoder-only transformer
- Multi-Head Self-Attention
- GELU activations
- LayerNorm (Pre-Norm)
- Flash Attention enabled during training
- Positional embeddings
- Weight decay + AdamW (fused)
- Mixed Precision (AMP FP16)
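
For orientation, the hyperparameters above can be collected into a nanoGPT-style config object. The field names below (`block_size`, `n_layer`, `n_head`, `n_embd`) are an assumption about this repo's `GPT` class rather than its verified interface; the values are the ones listed in the table and list above.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50304  # GPT-2 BPE vocabulary (likely padded up from tiktoken's 50257)
    n_layer: int = 12        # transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # hidden size
```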

---

# 🛠️ Training Details

## **Pretraining on Ultra-FineWeb Edu (5B token subset)**

- **Dataset:** Ultra-FineWeb Edu (educational, high-quality text)
- **Tokenizer:** GPT-2 BPE (50304 vocab)
- **Steps:** Thousands of steps on a Kaggle T4 GPU
- **Techniques used** (a sketch of the warmup + cosine LR schedule follows this list):
  - Flash Attention
  - Gradient Accumulation
  - FP16 AMP
  - Cosine Learning Rate Decay
  - Warmup
  - Fused AdamW
  - Weight Decay
  - Checkpointing every 500 steps
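
The warmup plus cosine-decay schedule listed above is usually implemented as a small function of the step index. The sketch below shows the standard form; the `max_lr`, `warmup_steps`, and `max_steps` values are illustrative placeholders, not the values used in this run.

```python
import math

# Illustrative hyperparameters only; the real run's values are not published here.
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 700
max_steps = 10_000

def get_lr(step: int) -> float:
    # 1) Linear warmup from ~0 to max_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) After max_steps, hold at the minimum learning rate.
    if step > max_steps:
        return min_lr
    # 3) Cosine decay from max_lr down to min_lr in between.
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
```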

---

## **Supervised Fine-Tuning (SFT) for QA**

- **Dataset:** Custom QA JSONL
- **Format:** `{"instruction": "...", "response": "..."}` (see the formatting sketch below)
- **Loss:** Cross-entropy
- **Goal:** Improve chat quality and answer correctness for QA
- **Result:** Stable ~0.6–0.7 loss, improved reasoning
- **Tokens:** ~100K–200K from the curated dataset
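
The exact prompt template used during SFT is not documented on this card. The sketch below shows one plausible way the `instruction`/`response` records could be rendered into the `Q: ... A: ...` style used in the inference examples further down; the file name `qa_dataset.jsonl` and the template itself are assumptions for illustration.

```python
import json

import tiktoken

enc = tiktoken.get_encoding("gpt2")  # same BPE scheme as GPT-2

def render_example(record: dict) -> str:
    # Mirror the "Q: ...\nA: ..." prompt style used in the inference examples below.
    return f"Q: {record['instruction']}\nA: {record['response']}"

examples = []
with open("qa_dataset.jsonl") as f:  # hypothetical file name
    for line in f:
        record = json.loads(line)
        tokens = enc.encode(render_example(record))
        # Append an end-of-text token so examples can be packed into one stream.
        examples.append(tokens + [enc.eot_token])

print(f"{len(examples)} examples, {sum(map(len, examples))} tokens total")
```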

---

# 📚 Datasets Used

### **Pretraining Dataset: Ultra-FineWeb Edu**

- Educational subset of Ultra-FineWeb
- High-quality English text
- Filtered for correctness
- Contains textbook-like explanations
- Clean enough to bootstrap small LLMs

### **Fine-Tuning Dataset: Custom QA JSONL**

- Australian knowledge
- Definitions
- Technology facts
- Simple reasoning questions
- Clean, short answers

---

# 🔤 Tokenizer

- GPT-2 BPE
- 50304 vocab
- Identical formatting to the GPT-2 tokenizer
- Tokenization done via `tiktoken` (see the snippet below)
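
A minimal encode/decode round trip with `tiktoken`. Note that `tiktoken`'s GPT-2 encoding has 50,257 tokens; the model's 50,304-entry embedding table presumably pads this up to a rounder number, so the extra indices are never produced by the tokenizer (this padding interpretation is an assumption, not something the repo confirms).

```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")

tokens = enc.encode("What is the capital of Australia?")
print(tokens)              # list of BPE token ids
print(enc.n_vocab)         # 50257 -- smaller than the model's 50304 embedding rows
print(enc.decode(tokens))  # round-trips back to the original string
```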

---

# 💻 How to Use (Karpathy Repo)

### **1. Clone the repo**

```bash
git clone https://github.com/shubharthaksangharsha/karpathy
cd karpathy/chapter-9-sft-rhlf-dpo-gpt2-124m
```

### **2. Run inference**

```python
import torch
from model import GPT

# Load the pretraining checkpoint; it stores both the config and the weights.
ckpt = torch.load("model_09535.pt", map_location="cpu")
model = GPT(config=ckpt['config'])
model.load_state_dict(ckpt['model'])
model.eval()

# Note: the tokenizer-based example near the end of this card encodes the
# prompt first; use that pattern if your generate() expects token ids.
out = model.generate("Who is the prime minister of Australia?", max_new_tokens=60)
print(out)
```

### **To run the QA model instead:**

```python
import torch
from model import GPT

# Load the SFT checkpoint instead of the pretraining checkpoint.
ckpt = torch.load("qa-sft_best.pt", map_location="cpu")
model = GPT(config=ckpt['config'])
model.load_state_dict(ckpt['model'])
model.eval()

out = model.generate("What is the capital of Australia?", max_new_tokens=60)
print(out)
```

---

# 🤗 How to Use (Hugging Face Transformers)

Because this is a **Karpathy-format checkpoint**, you cannot load it directly using:

```python
AutoModelForCausalLM.from_pretrained(...)
```

Instead, load the state dict manually:

```python
import torch

ckpt = torch.load("model_09535.pt", map_location="cpu")
state_dict = ckpt["model"]  # raw parameter tensors, not a ready-to-use nn.Module
```

⚠️ A conversion script is required for full HF `.from_pretrained()` compatibility.
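
No conversion script ships with this card, but the mapping is mechanical *if* the checkpoint uses nanoGPT-style parameter names (`transformer.wte`, `transformer.h.{i}.attn.c_attn`, ...). The sketch below only illustrates the idea under that assumption: the key names, the `_orig_mod.` prefix handling, and the vocab-size slice are all guesses that must be checked against the actual checkpoint before use.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

ckpt = torch.load("model_09535.pt", map_location="cpu")
sd = ckpt["model"]

# torch.compile often prefixes keys with "_orig_mod."; strip it if present.
sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()}

hf_model = GPT2LMHeadModel(GPT2Config(n_layer=12, n_head=12, n_embd=768,
                                      n_positions=1024, vocab_size=50257))
hf_sd = hf_model.state_dict()

# HF GPT-2 stores these projections as Conv1D, so their weights are transposed
# relative to a plain nn.Linear layout; everything else copies over one-to-one.
transposed = ("attn.c_attn.weight", "attn.c_proj.weight",
              "mlp.c_fc.weight", "mlp.c_proj.weight")

with torch.no_grad():
    for name, tensor in sd.items():
        if name.endswith(".attn.bias"):   # causal-mask buffer, not a weight
            continue
        if name in ("transformer.wte.weight", "lm_head.weight"):
            tensor = tensor[:50257]       # drop padded vocab rows (50304 -> 50257)
        if name.endswith(transposed):
            tensor = tensor.t()
        hf_sd[name].copy_(tensor)         # raises KeyError if the key names differ

hf_model.save_pretrained("gpt2-124m-qa-hf")  # then loadable via from_pretrained()
```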

---

# 📝 Example Inference (QA Model)

```python
import torch
from model import GPT
from tokenizer import GPT2Tokenizer

tokenizer = GPT2Tokenizer()

ckpt = torch.load("qa-sft_best.pt", map_location="cpu")
model = GPT(config=ckpt['config'])
model.load_state_dict(ckpt['model'])
model.eval()

# Prompts follow the "Q: ... A:" format used during SFT.
prompt = "Q: What is the capital of Australia?\nA:"
tokens = tokenizer.encode(prompt)
out = model.generate(tokens, max_new_tokens=60)
print(tokenizer.decode(out))
```

---

# ⚠️ Limitations

- Only 124M parameters (not SOTA)
- Limited reasoning ability
- Fine-tuned on a small custom QA set
- Not RLHF-finetuned (SFT only)
- Not safety-aligned or filtered

---

# 📄 License

This work is based on Andrej Karpathy’s "Neural Networks: Zero to Hero" course and, like that course’s code, is released under the MIT license.
|