"""
MiniMind Max2 Quick Start Example
Demonstrates basic usage of the Max2 model.
"""

import sys
from pathlib import Path

# Put the repository root on sys.path so the project-local `configs` and
# `model` packages resolve when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
|
def main():
    """Run the Max2 quick-start demo: build a model, inspect it, do a
    forward pass, sample a short generation, and report GPU memory."""
    rule = "=" * 60
    print(rule)
    print("MiniMind Max2 Quick Start")
    print(rule)

    # Deferred imports: these resolve via the sys.path tweak done at module
    # import time, and keep the module importable without the project tree.
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    # --- 1. Build the smallest configuration ---
    cfg_name = "max2-nano"
    print(f"\n1. Creating {cfg_name} model...")

    cfg = get_config(cfg_name)
    net = Max2ForCausalLM(cfg)

    # Parameter budget is estimated from the config alone (no weights needed).
    stats = estimate_params(cfg)
    print(f" Total parameters: {stats['total_params_b']:.3f}B")
    print(f" Active parameters: {stats['active_params_b']:.3f}B")
    print(f" Activation ratio: {stats['activation_ratio']:.1%}")
    print(f" Estimated size (INT4): {stats['estimated_size_int4_gb']:.2f}GB")

    # --- 2. Device placement: fp16 on CUDA, fp32 on CPU ---
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    run_dtype = torch.float16 if run_device == "cuda" else torch.float32
    net = net.to(device=run_device, dtype=run_dtype)
    print(f"\n2. Model loaded on {run_device} with {run_dtype}")

    # --- 3. Forward pass on random token ids (labels = inputs for LM loss) ---
    print("\n3. Testing forward pass...")
    n_batch, n_seq = 2, 64
    tokens = torch.randint(0, cfg.vocab_size, (n_batch, n_seq), device=run_device)

    net.eval()
    with torch.no_grad():
        # Model returns (loss, logits, <unused>, moe_aux_loss) — assumed from
        # this call site; confirm against Max2ForCausalLM.forward.
        loss, logits, _, aux_loss = net(tokens, labels=tokens)

    print(f" Input shape: {tokens.shape}")
    print(f" Output logits shape: {logits.shape}")
    print(f" Loss: {loss:.4f}")
    print(f" MoE auxiliary loss: {aux_loss:.6f}")

    # --- 4. Sampled generation from a short random prompt ---
    print("\n4. Testing generation...")
    seed_ids = torch.randint(0, cfg.vocab_size, (1, 10), device=run_device)

    with torch.no_grad():
        out_ids = net.generate(
            seed_ids,
            max_new_tokens=20,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            do_sample=True,
        )

    print(f" Prompt length: {seed_ids.shape[1]}")
    print(f" Generated length: {out_ids.shape[1]}")
    print(f" New tokens: {out_ids.shape[1] - seed_ids.shape[1]}")

    # --- 5. GPU memory high-water mark (meaningful only on CUDA) ---
    if run_device == "cuda":
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {peak_gb:.2f}GB")

    print("\n" + rule)
    print("Quick start complete!")
    print(rule)

    print("\nNext steps:")
    print(" - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print(" - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print(" - See README.md for full documentation")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|