|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Training Data Loader for Language Model Training
|
|
|
|
|
|
This module provides efficient data loading and batching for training GPT-style
|
|
|
language models. It handles text preprocessing, tokenization, and creates
|
|
|
batches suitable for autoregressive language modeling.
|
|
|
|
|
|
FEATURES:
|
|
|
- Memory-efficient text loading with sliding window
|
|
|
- Automatic tokenization using trained SentencePiece model
|
|
|
- Configurable sequence length and batch size
|
|
|
- CPU-optimized data loading for limited hardware
|
|
|
- Support for training data validation and statistics
|
|
|
|
|
|
MEMORY OPTIMIZATION:
|
|
|
- Streaming data loading (doesn't load entire dataset to memory)
|
|
|
- Configurable chunk sizes for large files
|
|
|
- Efficient tensor creation and batching
|
|
|
- Garbage collection hints for memory management
|
|
|
|
|
|
Usage:
|
|
|
from data_loader import TextDataLoader
|
|
|
|
|
|
loader = TextDataLoader(
|
|
|
data_file="data/clean/training_data.txt",
|
|
|
tokenizer_path="data/tokenizer/tokenizer.model",
|
|
|
seq_len=512,
|
|
|
batch_size=4
|
|
|
)
|
|
|
|
|
|
for batch in loader:
|
|
|
input_ids, targets = batch
|
|
|
# input_ids: (batch_size, seq_len)
|
|
|
# targets: (batch_size, seq_len) - shifted by 1 for next token prediction
|
|
|
|
|
|
Author: Louis Chua Bean Chong
|
|
|
License: GPLv3
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import gc
|
|
|
import random
|
|
|
import torch
|
|
|
import time
|
|
|
from typing import Iterator, Tuple, List, Optional
|
|
|
from pathlib import Path
|
|
|
|
|
|
try:
|
|
|
import sentencepiece as spm
|
|
|
except ImportError:
|
|
|
print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
|
|
|
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
class TextDataLoader:
|
|
|
"""
|
|
|
Efficient data loader for autoregressive language model training.
|
|
|
|
|
|
This class handles loading text data, tokenizing it using SentencePiece,
|
|
|
and creating batches suitable for next-token prediction training.
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
data_file: str,
|
|
|
tokenizer_path: str,
|
|
|
seq_len: int = 512,
|
|
|
batch_size: int = 4,
|
|
|
chunk_size: int = 1000000,
|
|
|
shuffle: bool = True,
|
|
|
seed: int = 42
|
|
|
):
|
|
|
"""
|
|
|
Initialize the data loader.
|
|
|
|
|
|
Args:
|
|
|
data_file: Path to training text file (one passage per line)
|
|
|
tokenizer_path: Path to trained SentencePiece model
|
|
|
seq_len: Maximum sequence length for training
|
|
|
batch_size: Batch size for training
|
|
|
chunk_size: Number of lines to read in memory at once
|
|
|
shuffle: Whether to shuffle training examples
|
|
|
seed: Random seed for reproducibility
|
|
|
"""
|
|
|
self.data_file = data_file
|
|
|
self.tokenizer_path = tokenizer_path
|
|
|
self.seq_len = seq_len
|
|
|
self.batch_size = batch_size
|
|
|
self.chunk_size = chunk_size
|
|
|
self.shuffle = shuffle
|
|
|
self.seed = seed
|
|
|
|
|
|
|
|
|
self._validate_inputs()
|
|
|
|
|
|
|
|
|
self.tokenizer = self._load_tokenizer()
|
|
|
|
|
|
|
|
|
self.total_lines = self._count_lines()
|
|
|
self.current_line = 0
|
|
|
|
|
|
|
|
|
random.seed(seed)
|
|
|
|
|
|
print(f"π TextDataLoader initialized")
|
|
|
print(f" Data file: {data_file}")
|
|
|
print(f" Total passages: {self.total_lines:,}")
|
|
|
print(f" Sequence length: {seq_len}")
|
|
|
print(f" Batch size: {batch_size}")
|
|
|
print(f" Vocabulary size: {self.tokenizer.vocab_size():,}")
|
|
|
|
|
|
def _validate_inputs(self) -> None:
|
|
|
"""Validate input parameters and file paths."""
|
|
|
if not os.path.exists(self.data_file):
|
|
|
raise FileNotFoundError(f"Training data file not found: {self.data_file}")
|
|
|
|
|
|
if not os.path.exists(self.tokenizer_path):
|
|
|
raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")
|
|
|
|
|
|
if self.seq_len <= 0:
|
|
|
raise ValueError(f"Sequence length must be positive, got {self.seq_len}")
|
|
|
|
|
|
if self.batch_size <= 0:
|
|
|
raise ValueError(f"Batch size must be positive, got {self.batch_size}")
|
|
|
|
|
|
if self.chunk_size <= 0:
|
|
|
raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")
|
|
|
|
|
|
def _load_tokenizer(self) -> spm.SentencePieceProcessor:
|
|
|
"""Load the trained SentencePiece tokenizer."""
|
|
|
try:
|
|
|
tokenizer = spm.SentencePieceProcessor()
|
|
|
tokenizer.load(self.tokenizer_path)
|
|
|
return tokenizer
|
|
|
except Exception as e:
|
|
|
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
|
|
|
|
|
def _count_lines(self) -> int:
|
|
|
"""Count total number of lines in the data file."""
|
|
|
print("π Counting training passages...")
|
|
|
start_time = time.time()
|
|
|
|
|
|
line_count = 0
|
|
|
with open(self.data_file, 'r', encoding='utf-8') as f:
|
|
|
for line in f:
|
|
|
if line.strip():
|
|
|
line_count += 1
|
|
|
|
|
|
count_time = time.time() - start_time
|
|
|
print(f"β Found {line_count:,} passages in {count_time:.1f}s")
|
|
|
|
|
|
return line_count
|
|
|
|
|
|
def _read_chunk(self, start_line: int = 0) -> List[str]:
|
|
|
"""
|
|
|
Read a chunk of lines from the data file.
|
|
|
|
|
|
Args:
|
|
|
            start_line: Index of the first passage (non-empty line) to read
|
|
|
|
|
|
Returns:
|
|
|
List of text passages
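
        Illustrative example (file contents and chunk_size are hypothetical):

            # chunk_size=2, non-empty lines: "passage A", "passage B", "passage C"
            # _read_chunk(0) -> ["passage A", "passage B"]
            # _read_chunk(2) -> ["passage C"]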
|
|
|
"""
|
|
|
chunk = []
|
|
|
current_line = 0
|
|
|
lines_read = 0
|
|
|
|
|
|
        with open(self.data_file, 'r', encoding='utf-8') as f:
            for line in f:
                text = line.strip()
                if not text:
                    # Blank lines are not passages (keeps indexing consistent with
                    # _count_lines and the iterator's current_line counter)
                    continue

                # Skip passages before the requested starting offset
                if current_line < start_line:
                    current_line += 1
                    continue

                chunk.append(text)
                lines_read += 1

                if lines_read >= self.chunk_size:
                    break
|
|
|
|
|
|
return chunk
|
|
|
|
|
|
def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
|
|
|
"""
|
|
|
Tokenize a list of text passages using SentencePiece tokenizer.
|
|
|
|
|
|
This method converts raw text into token ID sequences suitable for language model training.
|
|
|
It handles special tokens (BOS/EOS) and length constraints for efficient training.
|
|
|
|
|
|
Text processing pipeline:
|
|
|
1. Add BOS (Beginning of Sequence) token to mark sequence start
|
|
|
2. Tokenize text using trained SentencePiece model (subword tokenization)
|
|
|
3. Truncate sequences that exceed maximum length
|
|
|
4. Add EOS (End of Sequence) token to mark sequence end
|
|
|
|
|
|
Special token handling:
|
|
|
- BOS token helps model learn to generate text from scratch
|
|
|
- EOS token signals natural sequence endings
|
|
|
- These tokens are crucial for proper autoregressive generation
|
|
|
|
|
|
Args:
|
|
|
            texts: List of text passages (typically Wikipedia passages from SQuAD)
|
|
|
Each passage should be a complete, coherent text segment
|
|
|
|
|
|
Returns:
|
|
|
List of token ID sequences, where each sequence is a list of integers
|
|
|
representing subword tokens from the SentencePiece vocabulary
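
        Illustrative example (token IDs are placeholders; BOS=1 and EOS=2 are
        assumed here, real IDs come from the trained SentencePiece model):

            # "The cat sat" -> hypothetical subword IDs [87, 431, 902]
            # result: [1, 87, 431, 902, 2]   (BOS + subword tokens + EOS)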
|
|
|
"""
|
|
|
tokenized = []
|
|
|
|
|
|
for text in texts:
|
|
|
try:
|
|
|
|
|
|
|
|
|
|
|
|
                # Prepend the BOS id, then encode the passage into subword IDs
                tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                # Truncate so there is still room for the EOS token within seq_len
                if len(tokens) > self.seq_len - 1:
                    tokens = tokens[:self.seq_len - 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                # Mark the end of the passage with the EOS id
                tokens.append(self.tokenizer.eos_id())
|
|
|
|
|
|
|
|
|
                # Skip passages that are essentially empty after tokenization
                if len(tokens) <= 2:
|
|
|
print(f"β οΈ Skipping very short text: {text[:50]}...")
|
|
|
continue
|
|
|
|
|
|
tokenized.append(tokens)
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
|
print(f"β οΈ Failed to tokenize passage: {text[:50]}... Error: {e}")
|
|
|
continue
|
|
|
|
|
|
|
|
|
if tokenized:
|
|
|
avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
|
|
|
print(f"π Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")
|
|
|
|
|
|
return tokenized
|
|
|
|
|
|
def _create_training_examples(self, token_sequences: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
|
|
|
"""
|
|
|
Create training examples with input and target sequences.
|
|
|
|
|
|
For autoregressive training, targets are inputs shifted by one position.
|
|
|
|
|
|
Args:
|
|
|
token_sequences: List of tokenized sequences
|
|
|
|
|
|
Returns:
|
|
|
List of (input_ids, target_ids) tuples
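
        Illustrative example (token IDs are placeholders):

            tokens     = [1, 10, 11, 12, 2]   # BOS, three subwords, EOS
            input_ids  = tokens[:-1]          # [1, 10, 11, 12]
            target_ids = tokens[1:]           # [10, 11, 12, 2]
            # target_ids[i] is the next token after input_ids[i]; short examples
            # are then padded to seq_len (pad_id for inputs, -1 for targets)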
|
|
|
"""
|
|
|
examples = []
|
|
|
|
|
|
for tokens in token_sequences:
|
|
|
if len(tokens) < 2:
|
|
|
continue
|
|
|
|
|
|
|
|
|
            # Long passages: create overlapping training windows with a 50% stride
            if len(tokens) > self.seq_len:
|
|
|
|
|
|
stride = self.seq_len // 2
|
|
|
for i in range(0, len(tokens) - self.seq_len, stride):
|
|
|
input_ids = tokens[i:i + self.seq_len]
|
|
|
target_ids = tokens[i + 1:i + self.seq_len + 1]
|
|
|
examples.append((input_ids, target_ids))
|
|
|
else:
|
|
|
|
|
|
                # Short passages: inputs and targets are offset by one position
                input_ids = tokens[:-1]
|
|
|
target_ids = tokens[1:]
|
|
|
|
|
|
|
|
|
                # Pad to seq_len: inputs with pad_id (assumes the tokenizer defines
                # one, i.e. pad_id() >= 0), targets with -1 so the loss ignores them
                while len(input_ids) < self.seq_len:
                    input_ids.append(self.tokenizer.pad_id())
                    target_ids.append(-1)
|
|
|
|
|
|
|
|
|
input_ids = input_ids[:self.seq_len]
|
|
|
target_ids = target_ids[:self.seq_len]
|
|
|
|
|
|
examples.append((input_ids, target_ids))
|
|
|
|
|
|
return examples
|
|
|
|
|
|
def _create_batch(self, examples: List[Tuple[List[int], List[int]]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
"""
|
|
|
Create a batch tensor from training examples.
|
|
|
|
|
|
Args:
|
|
|
examples: List of (input_ids, target_ids) tuples
|
|
|
|
|
|
Returns:
|
|
|
Tuple of (input_tensor, target_tensor)
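
        Illustrative example (shapes only; batch_size=2 and seq_len=4 are assumed):

            input_ids.shape  == (2, 4)   # dtype torch.long
            target_ids.shape == (2, 4)   # unfilled positions hold -1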
|
|
|
"""
|
|
|
if not examples:
|
|
|
raise ValueError("Cannot create batch from empty examples")
|
|
|
|
|
|
batch_size = len(examples)
|
|
|
|
|
|
|
|
|
        # Every example is already exactly seq_len long, so these initial values are
        # fully overwritten; targets default to -1, which the loss should ignore
        input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
        target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)
|
|
|
|
|
|
|
|
|
for i, (inp, tgt) in enumerate(examples):
|
|
|
input_ids[i, :len(inp)] = torch.tensor(inp, dtype=torch.long)
|
|
|
target_ids[i, :len(tgt)] = torch.tensor(tgt, dtype=torch.long)
|
|
|
|
|
|
return input_ids, target_ids
|
|
|
|
|
|
def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
|
|
|
"""
|
|
|
Iterate over training batches.
|
|
|
|
|
|
Yields:
|
|
|
Tuple of (input_ids, target_ids) tensors
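
        Illustrative usage (shapes assume batch_size=4 and seq_len=512, the
        constructor defaults, not fixed values):

            for input_ids, target_ids in loader:
                # input_ids:  torch.LongTensor, shape (4, 512)
                # target_ids: torch.LongTensor, shape (4, 512); -1 marks positions
                # the loss should ignore
                pass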
|
|
|
"""
|
|
|
self.current_line = 0
|
|
|
|
|
|
while self.current_line < self.total_lines:
|
|
|
|
|
|
            # Stream the next chunk of passages from disk
            texts = self._read_chunk(self.current_line)
|
|
|
if not texts:
|
|
|
break
|
|
|
|
|
|
|
|
|
token_sequences = self._tokenize_texts(texts)
|
|
|
|
|
|
|
|
|
examples = self._create_training_examples(token_sequences)
|
|
|
|
|
|
|
|
|
if self.shuffle:
|
|
|
random.shuffle(examples)
|
|
|
|
|
|
|
|
|
            # Yield only full batches; a trailing partial batch is skipped
            for i in range(0, len(examples), self.batch_size):
|
|
|
batch_examples = examples[i:i + self.batch_size]
|
|
|
|
|
|
if len(batch_examples) == self.batch_size:
|
|
|
try:
|
|
|
input_ids, target_ids = self._create_batch(batch_examples)
|
|
|
yield input_ids, target_ids
|
|
|
except Exception as e:
|
|
|
print(f"β οΈ Failed to create batch: {e}")
|
|
|
continue
|
|
|
|
|
|
|
|
|
self.current_line += len(texts)
|
|
|
|
|
|
|
|
|
            # Release the processed chunk before loading the next one
            del texts, token_sequences, examples
            gc.collect()
|
|
|
|
|
|
def get_data_stats(self) -> dict:
|
|
|
"""
|
|
|
Get statistics about the training data.
|
|
|
|
|
|
Returns:
|
|
|
Dictionary with data statistics
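
        Illustrative return value (numbers are made up):

            {
                "total_passages": 50000,
                "avg_tokens_per_passage": 180.3,
                "estimated_total_tokens": 9015000,
                "estimated_batches_per_epoch": 12500,
                ...
            }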
|
|
|
"""
|
|
|
print("π Analyzing training data...")
|
|
|
|
|
|
|
|
|
        # Sample up to 100 passages from the start of the file to estimate statistics
        sample_texts = self._read_chunk(0)[:100]
|
|
|
token_sequences = self._tokenize_texts(sample_texts)
|
|
|
|
|
|
if token_sequences:
|
|
|
sequence_lengths = [len(seq) for seq in token_sequences]
|
|
|
avg_length = sum(sequence_lengths) / len(sequence_lengths)
|
|
|
max_length = max(sequence_lengths)
|
|
|
min_length = min(sequence_lengths)
|
|
|
else:
|
|
|
avg_length = max_length = min_length = 0
|
|
|
|
|
|
|
|
|
        # Extrapolate the total token count from the sampled average
        estimated_total_tokens = int(avg_length * self.total_lines)
|
|
|
|
|
|
|
|
|
        # Rough estimate: long passages yield multiple sliding-window examples
        examples_per_passage = max(1, avg_length // self.seq_len)
|
|
|
total_examples = int(self.total_lines * examples_per_passage)
|
|
|
batches_per_epoch = total_examples // self.batch_size
|
|
|
|
|
|
stats = {
|
|
|
"total_passages": self.total_lines,
|
|
|
"avg_tokens_per_passage": avg_length,
|
|
|
"min_tokens_per_passage": min_length,
|
|
|
"max_tokens_per_passage": max_length,
|
|
|
"estimated_total_tokens": estimated_total_tokens,
|
|
|
"estimated_examples_per_epoch": total_examples,
|
|
|
"estimated_batches_per_epoch": batches_per_epoch,
|
|
|
"sequence_length": self.seq_len,
|
|
|
"batch_size": self.batch_size,
|
|
|
"vocabulary_size": self.tokenizer.vocab_size()
|
|
|
}
|
|
|
|
|
|
print(f"β Data analysis complete:")
|
|
|
print(f" Total passages: {stats['total_passages']:,}")
|
|
|
print(f" Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
|
|
|
print(f" Estimated total tokens: {stats['estimated_total_tokens']:,}")
|
|
|
print(f" Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")
|
|
|
|
|
|
return stats
|
|
|
|
|
|
|
|
|
def test_data_loader():
|
|
|
"""Test function for the data loader."""
|
|
|
print("π§ͺ Testing TextDataLoader...")
|
|
|
|
|
|
|
|
|
try:
|
|
|
loader = TextDataLoader(
|
|
|
data_file="data/clean/training_data.txt",
|
|
|
tokenizer_path="data/tokenizer/tokenizer.model",
|
|
|
seq_len=128,
|
|
|
batch_size=2,
|
|
|
chunk_size=10
|
|
|
)
|
|
|
|
|
|
|
|
|
stats = loader.get_data_stats()
|
|
|
|
|
|
|
|
|
print("\nπ Testing batch iteration...")
|
|
|
start_time = time.time()
|
|
|
batch_count = 0
|
|
|
|
|
|
for batch_idx, (input_ids, target_ids) in enumerate(loader):
|
|
|
batch_count += 1
|
|
|
|
|
|
print(f"Batch {batch_idx + 1}:")
|
|
|
print(f" Input shape: {input_ids.shape}")
|
|
|
print(f" Target shape: {target_ids.shape}")
|
|
|
print(f" Sample input tokens: {input_ids[0][:10].tolist()}")
|
|
|
print(f" Sample target tokens: {target_ids[0][:10].tolist()}")
|
|
|
|
|
|
if batch_idx >= 2:
|
|
|
break
|
|
|
|
|
|
test_time = time.time() - start_time
|
|
|
print(f"\nβ Data loader test completed successfully!")
|
|
|
print(f" Processed {batch_count} batches in {test_time:.2f}s")
|
|
|
print(f" Average time per batch: {test_time/max(1, batch_count):.2f}s")
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"β Data loader test failed: {e}")
|
|
|
import traceback
|
|
|
traceback.print_exc()
|
|
|
return False
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
test_data_loader() |