---
language:
- tr
license: llama3.1
base_model: meta-llama/Llama-3.1-8B-Instruct
tags:
- legal
- turkish
- llama-3.1
- fp8
- bfloat16
- mixed-precision
- question-answering
- fsdp-v2
- pytorch
- distributed-training
datasets:
- newmindai/EuroHPC-Legal
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: Llama-3.1-8B-Instruct-w16a8-8nodes-bs64
  results: []
---

## **Abstract**

This model was trained as part of our study comparing **FSDP2 with bfloat16 precision** against **FSDP2 with FP8 mixed precision (bf16-fp8)**, using `meta-llama/Llama-3.1-8B-Instruct` as the base model. The model was loaded with `torch_dtype = bfloat16`, and for FP8 + FSDP2 compatibility it was wrapped per layer rather than as a whole model, which avoids dimension-misalignment issues. During the forward and backward passes the `float8` variants use the default **tensorwise** quantization scaling recipe, and we set `pad_inner_dim=True` to automatically pad inner dimensions to a multiple of 16, as required for FP8.

```python
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    precompute_float8_dynamic_scale_for_fsdp,
)
from torch.distributed.fsdp import fully_shard  # FSDP2 API

# Convert eligible linear layers to float8 training, padding inner dimensions
# so they are divisible by 16 and enabling the FP8 all-gather path.
config = Float8LinearConfig(
    pad_inner_dim=True,
    enable_fsdp_float8_all_gather=True,
)
model = convert_to_float8_training(model, config=config)

if use_fp8:
    # Shard each transformer block separately (per-layer FSDP2 wrapping),
    # then shard the embedding layer and the output head.
    for i, layer in enumerate(model.model.layers):
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model.model.embed_tokens, **fsdp_kwargs)
    fully_shard(model.lm_head, **fsdp_kwargs)
```

## **Base Model Technical Specifications**

- **Parameters**: 8 Billion
- **Architecture Family**: Llama 3.1
- **Maximum Position Embeddings**: 131,072
- **Attention Heads**: 32 (`num_attention_heads`)
- **Key-Value Heads**: 8 (`num_key_value_heads`)
- **Hidden Layers**: 32 (`num_hidden_layers`)
- **Hidden Size**: 4,096 (`hidden_size`)
- **Intermediate Size**: 14,336
- **Vocabulary Size**: 128,256
- **Precision**: bfloat16
- **RoPE Scaling**: type `llama3`, factor = 8.0
- **RMS Norm Epsilon**: 1e-05
- **Activation**: SiLU

## **Training Methodology**

### *Training Configuration*

- **Model**: `meta-llama/Llama-3.1-8B-Instruct`
- **Sequence Length**: 4,096 (`seq_len`)
- **Epochs**: 2
- **Per-Device Micro Batch Size**: 8
- **Gradient Accumulation**: 8
- **GPUs per Node**: 4 (via `CUDA_VISIBLE_DEVICES=0,1,2,3`)
- **dtype**: `bf16` with `fp8=true`
  - Weights: bfloat16
  - Activations: float8
- **Optimizer**: AdamW
  - Learning Rate: 2e-5
  - Weight Decay: 0.01
  - Betas: (0.9, 0.95)
  - Epsilon: 1e-8
- **LR Scheduler**: Cosine; warmup = 10% (`warmup_ratio=0.1`); `warmup_steps=100` also set
- **Max Grad Norm**: 1.0
- **Gradient Checkpointing**: Enabled
- **Checkpointing**: every 10 steps; keep last 5; select best by `eval_loss`
- **Logging**: every step to file; Weights & Biases in offline mode
- **Seed**: 100
- **Distributed Training**: `torch.distributed.run` (8 nodes, multi-GPU)
  - FSDP2 (Fully Sharded Data Parallel v2)

### *Setups*

- **Precision**: Half-precision bfloat16 used as the data type and for computation.
- **Hardware**: HPC (EuroHPC/BSC-class), 8 nodes with 4 × NVIDIA H100 GPUs each.
- **Framework**: PyTorch with `torchrun` for distributed training (see the optimizer/scheduler sketch below).
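Below is a minimal sketch of how the optimizer and cosine schedule described above could be constructed. The placeholder model, the step counts, and the launch command are illustrative assumptions, not the project's actual training script.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# Illustrative multi-node launch (8 nodes x 4 GPUs), e.g.:
#   torchrun --nnodes=8 --nproc_per_node=4 train.py

# Placeholder module; in the real run this is the FSDP2-wrapped Llama model.
model = torch.nn.Linear(4096, 4096)

# Hypothetical step count: the actual value depends on dataset size,
# world size, and gradient accumulation for each run.
total_steps = 1000
warmup_steps = int(0.1 * total_steps)  # warmup_ratio = 0.1

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
```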
### *Dependencies*

| Package | Version |
|---------|---------|
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |

## Job Details

| Model | Job ID | Runtime (mins) | Nodes | GPUs/node | Node-hours | GPU-hours | micro-batch | batch-size | gradient_accumulation | total_batch_size |
| ---------------------------------------- | -------- | -------------- | ----- | --------- | --------- | ---------- | ----------- | ---------- | --------------------- | ---------------- |
| Llama-3.1-8B-Instruct_w16a8_rw | 31768103 | 115.75 | 1 | 4 | **1.929** | **7.716** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 31837629 | 109.00 | 1 | 4 | **1.816** | **7.266** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 31768031 | 64.00 | 1 | 4 | **1.066** | **4.266** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-tw | 31768074 | 138.75 | 1 | 4 | **2.312** | **9.250** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 31768093 | 123.75 | 1 | 4 | **2.062** | **8.250** | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 31478433 | 31.75 | 4 | 4 | **2.117** | **8.467** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 31478468 | 39.75 | 4 | 4 | **2.650** | **10.600** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 31476914 | 22.00 | 8 | 4 | **2.933** | **11.733** | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 31476844 | 23.50 | 8 | 4 | **3.133** | **12.533** | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 31476914 | 22.00 | 8 | 4 | **2.933** | **11.733** | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 31476844 | 23.50 | 8 | 4 | **3.133** | **12.533** | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-rw_4nodes | 33477070 | 39.75 | 4 | 4 | **2.650** | **10.600** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-rw-8nodes | 33476690 | 23.50 | 8 | 4 | **3.133** | **12.533** | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-rw_with_gw_hp_4nodes | 33477179 | 37.43 | 4 | 4 | **2.495** | **9.982** | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-rw-with-gw-hp-8nodes | 33476618 | 22.13 | 8 | 4 | **2.951** | **11.802** | 4 | 4 | 8 | 1024 |
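The Node-hours and GPU-hours columns follow directly from the runtime and node count. A minimal sketch of that accounting (the function name is illustrative) is shown below.

```python
def job_cost(runtime_minutes: float, nodes: int, gpus_per_node: int = 4) -> tuple[float, float]:
    """Return (node-hours, GPU-hours) for a job, as reported in the table above."""
    node_hours = runtime_minutes / 60 * nodes
    gpu_hours = node_hours * gpus_per_node
    return node_hours, gpu_hours

# Example: the 8-node bs64 run (22.00 minutes on 8 nodes of 4 x H100)
print(job_cost(22.00, nodes=8))  # -> (2.933..., 11.733...)
```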
### *All 15 models, trained on 1, 4, and 8 nodes with both bf16 and bf16-fp8 configurations and different FP8 recipes*

| Perplexity (bf16 vs. bf16-fp8) | Accuracy (bf16 vs. bf16-fp8) | Loss (bf16 vs. bf16-fp8) | Memory allocation (bf16 vs. bf16-fp8) | GPU utilization (bf16 vs. bf16-fp8) |
|:--:|:--:|:--:|:--:|:--:|
| ![perp](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F683d4880e639f8d647355997%2Fij1hlr8E2qvdZM4uGC7lq.png) | ![acc](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F683d4880e639f8d647355997%2F7lO8mVKPnQQkyUTw8H6GA.png) | ![train_loss](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F683d4880e639f8d647355997%2FE73tvcC6u9VrvTIkznwU2.png) | ![memAlo](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F683d4880e639f8d647355997%2FNsHL_yaTtnjwD1e4EHcLP.png) | ![utils](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F683d4880e639f8d647355997%2F5mqF8xcRWuZdC_sGS9FCe.png) |

| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) | ± Std (val) | Steps |
| ---------------------------------------------- | ---------- | ---------------- | ---------------- | ---------------- | ------------------ | ------------- | -------------- | -------------- | -------------- | ---------------- | ----------- | ----- |
| Llama-3.1-8B-Instruct-w16a8-rw | 8 | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 | 0.0688 | |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 8 | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 | 0.0685 | |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 8 | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 | 0.0677 | |
| Llama-3.1-8B-Instruct-w16a16-tw | 8 | 3.1235 | 0.7203 | 0.9750 | 0.7612 | 0.3344 | 1.9113 | 0.8907 | 0.9831 | 0.8907 | 0.1897 | 312 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 8 | 3.1661 | 0.7261 | 0.9804 | 0.7672 | 0.3374 | 1.9230 | 0.8948 | 0.9867 | 0.8951 | 0.1906 | 312 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 32 | 3.2452 | 0.7414 | 0.9665 | 0.7504 | 0.4844 | 1.0538 | 0.8382 | 0.8844 | 0.8382 | 0.0725 | 70 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 32 | 3.2840 | 0.7478 | 0.9748 | 0.7581 | 0.4905 | 1.0701 | 0.8430 | 0.8922 | 0.8430 | 0.0764 | 70 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 32 | 3.2311 | 0.8448 | 1.1856 | 0.8448 | 0.6434 | 1.0257 | 0.8977 | 0.9460 | 0.8977 | 0.0568 | 35 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 32 | 3.3003 | 0.8473 | 1.1866 | 0.8473 | 0.6481 | 1.0203 | 0.8992 | 0.9445 | 0.8992 | 0.0539 | 35 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 64 | 3.2311 | 0.8448 | 1.1856 | 0.8448 | 0.6434 | 1.0257 | 0.8977 | 0.9460 | 0.8977 | 0.0568 | 35 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 64 | 3.3003 | 0.8473 | 1.1866 | 0.8473 | 0.6481 | 1.0203 | 0.8992 | 0.9445 | 0.8992 | 0.0539 | 35 |

## **Implementation**

### *Usage*

**Note**: the final model is saved in bfloat16 format. For inference, load it in bfloat16 or float16 as shown below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-8nodes-bs64"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )
print(tok.decode(out[0], skip_special_tokens=True))
```

## **Ethical Considerations and Disclaimers**

* For research and development purposes only; not a substitute for professional legal counsel.
* Users must ensure compliance with data protection and sector-specific regulations.
* Biases may exist in the domain data and in model outputs.

## **Model & Data Card Metadata**

* **Total Parameters**: 8,030,261,248
* **Serialized Size (approx.)**: 16,060,522,496 bytes (see the check below)
* **Config precision**: bfloat16
* **RoPE**: llama3 scaling, factor 8.0
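As a quick sanity check (a sketch, not part of the released artifacts), the serialized size is consistent with storing every parameter in bfloat16 at 2 bytes each:

```python
num_params = 8_030_261_248
bytes_per_param = 2  # bfloat16
print(num_params * bytes_per_param)  # 16060522496 bytes, matching the figure above
```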
## **References and Citations**

### *Base Model*

```bibtex
@misc{meta_llama31_8b_instruct,
  title        = {Llama 3.1 8B Instruct},
  author       = {Meta AI},
  year         = {2024},
  howpublished = {\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```

### *Training Dataset*

```bibtex
@misc{euro_hpc_legal,
  title        = {EuroHPC-Legal},
  author       = {newmindai},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```