lemms committed (verified)
Commit 1da0951 · Parent(s): e15a2f9

Add OpenLLM data_loader.py source file

Files changed (1):
  1. data_loader.py +480 -0
data_loader.py ADDED
@@ -0,0 +1,480 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
Training Data Loader for Language Model Training

This module provides efficient data loading and batching for training GPT-style
language models. It handles text preprocessing, tokenization, and creates
batches suitable for autoregressive language modeling.

FEATURES:
- Memory-efficient text loading with a sliding window
- Automatic tokenization using a trained SentencePiece model
- Configurable sequence length and batch size
- CPU-optimized data loading for limited hardware
- Support for training data validation and statistics

MEMORY OPTIMIZATION:
- Streaming data loading (does not load the entire dataset into memory)
- Configurable chunk sizes for large files
- Efficient tensor creation and batching
- Garbage collection hints for memory management

Usage:
    from data_loader import TextDataLoader

    loader = TextDataLoader(
        data_file="data/clean/training_data.txt",
        tokenizer_path="data/tokenizer/tokenizer.model",
        seq_len=512,
        batch_size=4,
    )

    for batch in loader:
        input_ids, targets = batch
        # input_ids: (batch_size, seq_len)
        # targets: (batch_size, seq_len) - shifted by 1 for next-token prediction

Author: Louis Chua Bean Chong
License: GPLv3
"""

import gc
import os
import random
import time
from typing import Iterator, List, Tuple

import torch

try:
    import sentencepiece as spm
except ImportError:
    raise SystemExit("ERROR: SentencePiece not installed. Run: pip install sentencepiece")


class TextDataLoader:
    """
    Efficient data loader for autoregressive language model training.

    This class handles loading text data, tokenizing it using SentencePiece,
    and creating batches suitable for next-token prediction training.
    """

    def __init__(
        self,
        data_file: str,
        tokenizer_path: str,
        seq_len: int = 512,
        batch_size: int = 4,
        chunk_size: int = 1000000,  # Lines to read at once
        shuffle: bool = True,
        seed: int = 42,
    ):
        """
        Initialize the data loader.

        Args:
            data_file: Path to training text file (one passage per line)
            tokenizer_path: Path to trained SentencePiece model
            seq_len: Maximum sequence length for training
            batch_size: Batch size for training
            chunk_size: Number of lines to read into memory at once
            shuffle: Whether to shuffle training examples
            seed: Random seed for reproducibility
        """
        self.data_file = data_file
        self.tokenizer_path = tokenizer_path
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.shuffle = shuffle
        self.seed = seed

        # Validate inputs
        self._validate_inputs()

        # Load tokenizer
        self.tokenizer = self._load_tokenizer()

        # Get data statistics
        self.total_lines = self._count_lines()
        self.current_line = 0

        # Set random seed for reproducibility
        random.seed(seed)

        print("📊 TextDataLoader initialized")
        print(f"   Data file: {data_file}")
        print(f"   Total passages: {self.total_lines:,}")
        print(f"   Sequence length: {seq_len}")
        print(f"   Batch size: {batch_size}")
        print(f"   Vocabulary size: {self.tokenizer.vocab_size():,}")

    def _validate_inputs(self) -> None:
        """Validate input parameters and file paths."""
        if not os.path.exists(self.data_file):
            raise FileNotFoundError(f"Training data file not found: {self.data_file}")

        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")

        if self.seq_len <= 0:
            raise ValueError(f"Sequence length must be positive, got {self.seq_len}")

        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")

        if self.chunk_size <= 0:
            raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")

    def _load_tokenizer(self) -> spm.SentencePieceProcessor:
        """Load the trained SentencePiece tokenizer."""
        try:
            tokenizer = spm.SentencePieceProcessor()
            tokenizer.load(self.tokenizer_path)
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")

    def _count_lines(self) -> int:
        """Count the total number of non-empty lines in the data file."""
        print("📝 Counting training passages...")
        start_time = time.time()

        line_count = 0
        with open(self.data_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():  # Only count non-empty lines
                    line_count += 1

        count_time = time.time() - start_time
        print(f"✓ Found {line_count:,} passages in {count_time:.1f}s")

        return line_count

    def _read_chunk(self, start_line: int = 0) -> List[str]:
        """
        Read a chunk of passages from the data file.

        Note: each call rescans the file from the beginning, which is
        O(file size) per chunk but keeps memory usage flat.

        Args:
            start_line: Passage index (non-empty line) to start reading from

        Returns:
            List of text passages
        """
        chunk = []
        current_line = 0
        lines_read = 0

        with open(self.data_file, "r", encoding="utf-8") as f:
            for line in f:
                text = line.strip()
                # Skip empty lines entirely so that indices count passages,
                # matching the counting done in _count_lines()
                if not text:
                    continue

                if current_line < start_line:
                    current_line += 1
                    continue

                chunk.append(text)
                lines_read += 1

                if lines_read >= self.chunk_size:
                    break

        return chunk

    def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
        """
        Tokenize a list of text passages using the SentencePiece tokenizer.

        This method converts raw text into token ID sequences suitable for
        language model training. It handles special tokens (BOS/EOS) and
        length constraints for efficient training.

        Text processing pipeline:
        1. Add BOS (Beginning of Sequence) token to mark the sequence start
        2. Tokenize text using the trained SentencePiece model (subword tokenization)
        3. Truncate sequences that exceed the maximum length
        4. Add EOS (End of Sequence) token to mark the sequence end

        Special token handling:
        - The BOS token helps the model learn to generate text from scratch
        - The EOS token signals natural sequence endings
        - These tokens are crucial for proper autoregressive generation

        Args:
            texts: List of text passages (typically Wikipedia passages from SQuAD).
                Each passage should be a complete, coherent text segment.

        Returns:
            List of token ID sequences, where each sequence is a list of integers
            representing subword tokens from the SentencePiece vocabulary
        """
        tokenized = []

        for text in texts:
            try:
                # Add the BOS token (id 1 by default in SentencePiece) at the
                # start; it signals sequence start and helps the model learn
                # proper sequence initialization during generation
                tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)

                # Truncate sequences that exceed the maximum context length.
                # Reserve one position for the EOS token by using (seq_len - 1)
                # so we never exceed the model's context window during training.
                if len(tokens) > self.seq_len - 1:
                    tokens = tokens[: self.seq_len - 1]
                    # NOTE: Truncation may cut off text mid-sentence, but this
                    # is acceptable for language modeling, where the model
                    # learns from partial contexts

                # Add the EOS token (id 2 by default in SentencePiece) at the
                # end; it teaches the model when to stop generating naturally
                tokens.append(self.tokenizer.eos_id())

                # Validate tokenization result
                if len(tokens) <= 2:  # Only BOS + EOS tokens, no actual content
                    print(f"⚠️ Skipping very short text: {text[:50]}...")
                    continue

                tokenized.append(tokens)

            except Exception as e:
                # Handle tokenization errors gracefully to avoid stopping training.
                # Common causes: encoding issues, very long texts, special characters.
                print(f"⚠️ Failed to tokenize passage: {text[:50]}... Error: {e}")
                continue

        # Log tokenization statistics for monitoring
        if tokenized:
            avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
            print(f"📊 Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")

        return tokenized
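
    # Illustrative sketch (not part of the loader): with SentencePiece
    # defaults (bos_id=1, eos_id=2) and seq_len=8, a passage that encodes
    # to the ids [17, 55, 203] comes back framed as
    #
    #     [1, 17, 55, 203, 2]      # [BOS] + subword ids + [EOS]
    #
    # and a passage encoding to 10 ids is first truncated to
    # seq_len - 1 = 7 tokens (BOS included), then EOS is appended, giving
    # exactly seq_len tokens. The ids 17, 55, 203 are made-up examples.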

    def _create_training_examples(
        self, token_sequences: List[List[int]]
    ) -> List[Tuple[List[int], List[int]]]:
        """
        Create training examples with input and target sequences.

        For autoregressive training, targets are the inputs shifted by one position.

        Args:
            token_sequences: List of tokenized sequences

        Returns:
            List of (input_ids, target_ids) tuples
        """
        examples = []

        # SentencePiece reports pad_id() == -1 when no pad token was defined
        # at tokenizer training time; fall back to 0 so padded inputs remain
        # valid embedding indices. (-1 in the *targets* is intentional: it is
        # ignored by the loss.)
        pad_id = self.tokenizer.pad_id()
        if pad_id < 0:
            pad_id = 0

        for tokens in token_sequences:
            if len(tokens) < 2:  # Need at least 2 tokens for an input/target pair
                continue

            # For sequences longer than seq_len, create multiple examples
            # with a sliding window
            if len(tokens) > self.seq_len:
                # Create overlapping windows (50% overlap for better learning)
                stride = self.seq_len // 2
                for i in range(0, len(tokens) - self.seq_len, stride):
                    input_ids = tokens[i : i + self.seq_len]
                    target_ids = tokens[i + 1 : i + self.seq_len + 1]
                    examples.append((input_ids, target_ids))
            else:
                input_ids = tokens[:-1]  # All but the last token
                target_ids = tokens[1:]  # All but the first token

                # Pad shorter sequences to seq_len
                while len(input_ids) < self.seq_len:
                    input_ids.append(pad_id)
                    target_ids.append(-1)  # -1 marks padding in targets (ignored in loss)

                # Truncate if still too long
                input_ids = input_ids[: self.seq_len]
                target_ids = target_ids[: self.seq_len]

                examples.append((input_ids, target_ids))

        return examples
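
    # Illustrative sketch (toy numbers, not part of the loader): with
    # seq_len=4 and tokens = [1, 11, 12, 2] ([BOS, a, b, EOS]), the
    # short-sequence branch yields
    #
    #     input_ids  = [1, 11, 12, 0]     # padded with pad_id
    #     target_ids = [11, 12, 2, -1]    # -1 is ignored by the loss
    #
    # so at every position the model predicts the *next* token. A longer
    # sequence, say 10 tokens with seq_len=4, instead yields overlapping
    # windows at offsets 0, 2 and 4 (stride = seq_len // 2 = 2).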

    def _create_batch(
        self, examples: List[Tuple[List[int], List[int]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create batch tensors from training examples.

        Args:
            examples: List of (input_ids, target_ids) tuples

        Returns:
            Tuple of (input_tensor, target_tensor)
        """
        if not examples:
            raise ValueError("Cannot create batch from empty examples")

        batch_size = len(examples)

        # Initialize tensors; targets default to -1 so any unfilled position
        # is skipped by a loss configured with ignore_index=-1
        input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
        target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)

        # Fill tensors
        for i, (inp, tgt) in enumerate(examples):
            input_ids[i, : len(inp)] = torch.tensor(inp, dtype=torch.long)
            target_ids[i, : len(tgt)] = torch.tensor(tgt, dtype=torch.long)

        return input_ids, target_ids
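
    # Illustrative sketch (not part of the loader): because padded target
    # positions hold -1, the training loss must be told to skip them.
    # PyTorch's cross_entropy defaults to ignore_index=-100, so it has to
    # be passed explicitly, e.g.
    #
    #     import torch.nn.functional as F
    #     logits = model(input_ids)              # (batch, seq_len, vocab)
    #     loss = F.cross_entropy(
    #         logits.view(-1, logits.size(-1)),  # (batch * seq_len, vocab)
    #         target_ids.view(-1),               # (batch * seq_len,)
    #         ignore_index=-1,
    #     )
    #
    # where `model` is a hypothetical GPT-style module returning per-token logits.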

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Iterate over training batches.

        Yields:
            Tuple of (input_ids, target_ids) tensors
        """
        self.current_line = 0

        while self.current_line < self.total_lines:
            # Read a chunk of text
            texts = self._read_chunk(self.current_line)
            if not texts:
                break

            # Tokenize texts
            token_sequences = self._tokenize_texts(texts)

            # Create training examples
            examples = self._create_training_examples(token_sequences)

            # Shuffle examples if requested
            if self.shuffle:
                random.shuffle(examples)

            # Create batches
            for i in range(0, len(examples), self.batch_size):
                batch_examples = examples[i : i + self.batch_size]

                if len(batch_examples) == self.batch_size:  # Only yield full batches
                    try:
                        input_ids, target_ids = self._create_batch(batch_examples)
                        yield input_ids, target_ids
                    except Exception as e:
                        print(f"⚠️ Failed to create batch: {e}")
                        continue

            # Update progress
            self.current_line += len(texts)

            # Clean up memory
            del texts, token_sequences, examples
            gc.collect()
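
    # Illustrative sketch (not part of the loader): a minimal epoch loop.
    # `model`, `optimizer` and `F` (torch.nn.functional) are hypothetical
    # stand-ins; only the loader usage itself comes from this module.
    #
    #     loader = TextDataLoader(
    #         data_file="data/clean/training_data.txt",
    #         tokenizer_path="data/tokenizer/tokenizer.model",
    #         seq_len=512,
    #         batch_size=4,
    #     )
    #     for input_ids, target_ids in loader:
    #         optimizer.zero_grad()
    #         logits = model(input_ids)
    #         loss = F.cross_entropy(
    #             logits.view(-1, logits.size(-1)),
    #             target_ids.view(-1),
    #             ignore_index=-1,
    #         )
    #         loss.backward()
    #         optimizer.step()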

    def get_data_stats(self) -> dict:
        """
        Get statistics about the training data.

        Returns:
            Dictionary with data statistics
        """
        print("📊 Analyzing training data...")

        # Sample some data to estimate statistics
        sample_texts = self._read_chunk(0)[:100]  # Sample the first 100 passages
        token_sequences = self._tokenize_texts(sample_texts)

        if token_sequences:
            sequence_lengths = [len(seq) for seq in token_sequences]
            avg_length = sum(sequence_lengths) / len(sequence_lengths)
            max_length = max(sequence_lengths)
            min_length = min(sequence_lengths)
        else:
            avg_length = max_length = min_length = 0

        # Estimate total tokens
        estimated_total_tokens = int(avg_length * self.total_lines)

        # Estimate the number of batches per epoch
        examples_per_passage = max(1, avg_length // self.seq_len)
        total_examples = int(self.total_lines * examples_per_passage)
        batches_per_epoch = total_examples // self.batch_size

        stats = {
            "total_passages": self.total_lines,
            "avg_tokens_per_passage": avg_length,
            "min_tokens_per_passage": min_length,
            "max_tokens_per_passage": max_length,
            "estimated_total_tokens": estimated_total_tokens,
            "estimated_examples_per_epoch": total_examples,
            "estimated_batches_per_epoch": batches_per_epoch,
            "sequence_length": self.seq_len,
            "batch_size": self.batch_size,
            "vocabulary_size": self.tokenizer.vocab_size(),
        }

        print("✓ Data analysis complete:")
        print(f"   Total passages: {stats['total_passages']:,}")
        print(f"   Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
        print(f"   Estimated total tokens: {stats['estimated_total_tokens']:,}")
        print(f"   Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")

        return stats
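

# Minimal runnable sketch (an illustration added here, not part of the original
# loader): demonstrates how the -1 padding convention in the target tensors
# interacts with a masked loss. Random logits stand in for a real model's output.
def _demo_target_masking() -> None:
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 2, 4, 10
    logits = torch.randn(batch_size, seq_len, vocab_size)
    targets = torch.tensor(
        [
            [5, 7, 2, -1],  # last position is padding
            [3, 2, -1, -1],  # last two positions are padding
        ]
    )

    # ignore_index=-1 makes the padded positions contribute nothing to the loss
    loss = F.cross_entropy(
        logits.view(-1, vocab_size), targets.view(-1), ignore_index=-1
    )
    print(f"Masked loss over the 5 real positions: {loss.item():.4f}")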


def test_data_loader():
    """Test function for the data loader."""
    print("🧪 Testing TextDataLoader...")

    # Test with small parameters
    try:
        loader = TextDataLoader(
            data_file="data/clean/training_data.txt",
            tokenizer_path="data/tokenizer/tokenizer.model",
            seq_len=128,
            batch_size=2,
            chunk_size=10,  # Small for testing
        )

        # Get data statistics
        stats = loader.get_data_stats()

        # Test iteration
        print("\n🔄 Testing batch iteration...")
        start_time = time.time()
        batch_count = 0

        for batch_idx, (input_ids, target_ids) in enumerate(loader):
            batch_count += 1

            print(f"Batch {batch_idx + 1}:")
            print(f"   Input shape: {input_ids.shape}")
            print(f"   Target shape: {target_ids.shape}")
            print(f"   Sample input tokens: {input_ids[0][:10].tolist()}")
            print(f"   Sample target tokens: {target_ids[0][:10].tolist()}")

            if batch_idx >= 2:  # Only test the first few batches
                break

        test_time = time.time() - start_time
        print("\n✓ Data loader test completed successfully!")
        print(f"   Processed {batch_count} batches in {test_time:.2f}s")
        print(f"   Average time per batch: {test_time / max(1, batch_count):.2f}s")

        return True

    except Exception as e:
        print(f"❌ Data loader test failed: {e}")
        import traceback

        traceback.print_exc()
        return False


if __name__ == "__main__":
    test_data_loader()