from unsloth import FastModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only

import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Let cuDNN benchmark and cache the fastest kernels for the fixed input shapes.
torch.backends.cudnn.benchmark = True

# Prefer bf16 on GPUs that support it; fall back to fp16 otherwise.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
max_seq_length = 2048


def load_and_prepare_datasets(train_data, eval_data, tokenizer, seed=3407):
    """Load train/eval data (local JSON/JSONL files or Hub datasets) and render
    each conversation to a single "text" field with the Gemma-3 chat template."""
    if train_data.endswith((".jsonl", ".json")):
        train_dataset = load_dataset("json", data_files=train_data, split="train").shuffle(
            seed=seed
        )
    else:
        train_dataset = load_dataset(train_data, split="train").shuffle(seed=seed)

    if eval_data.endswith((".jsonl", ".json")):
        eval_dataset = load_dataset("json", data_files=eval_data, split="train").shuffle(
            seed=seed
        )
    else:
        eval_dataset = load_dataset(eval_data, split="test").shuffle(seed=seed)

    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        # The Gemma-3 template prepends <bos>; strip it here because the
        # tokenizer adds it again at training time, and a doubled <bos>
        # degrades training.
        texts = [
            tokenizer.apply_chat_template(
                convo,
                tokenize=False,
                add_generation_prompt=False,
            ).removeprefix("<bos>")
            for convo in convos
        ]
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
    eval_dataset = eval_dataset.map(formatting_prompts_func, batched=True)
    return train_dataset, eval_dataset
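# --- Expected data shape (illustrative) --------------------------------------
# load_and_prepare_datasets() assumes every row carries a "conversations" list
# in the role/content format that apply_chat_template() consumes. The row below
# is a synthetic sketch of that shape, not an actual record from
# qmaru/gemma3-sms:
#
# {
#     "conversations": [
#         {"role": "user", "content": "Is this SMS spam? 'You won a prize!'"},
#         {"role": "assistant", "content": "spam"},
#     ]
# }
#
# (The Gemma-3 chat template maps the "assistant" role to its "model" turn.)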
def create_trainer(
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    output_dir,
    per_device_train_batch_size,
    gradient_accumulation_steps,
    learning_rate,
    num_train_epochs,
    optim,
    warmup_ratio=0.03,
    logging_steps=5,
    seed=3407,
):
    sft_config = SFTConfig(
        dataset_text_field="text",
        dataset_num_proc=4,
        packing=False,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_ratio=warmup_ratio,
        num_train_epochs=num_train_epochs,
        optim=optim,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        output_dir=output_dir,
        logging_steps=logging_steps,
        save_strategy="epoch",
        eval_strategy="epoch",
        load_best_model_at_end=True,
        report_to="none",
        seed=seed,
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=sft_config,
    )
    # Mask the loss on everything except the model's responses, so the user
    # turns are not trained on. The markers must include Gemma-3's
    # <start_of_turn> token, or nothing matches and no tokens get masked.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )
    return trainer


def train_full_finetuning(
    model_name="unsloth/gemma-3-270m-it",
    train_data="qmaru/gemma3-sms",
    eval_data="qmaru/gemma3-sms",
    max_seq_length=2048,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_ratio=0.03,
    num_train_epochs=3,
    output_dir="outputs_full_finetune",
    model_output_dir="model_full",
    logging_steps=5,
    seed=3407,
):
    model, tokenizer = FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=True,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = get_chat_template(tokenizer, chat_template="gemma3")
    train_dataset, eval_dataset = load_and_prepare_datasets(
        train_data, eval_data, tokenizer, seed
    )
    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        optim="adamw_torch",
        warmup_ratio=warmup_ratio,
        logging_steps=logging_steps,
        seed=seed,
    )
    trainer_stats = trainer.train()

    model.save_pretrained(model_output_dir)
    tokenizer.save_pretrained(model_output_dir)
    model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="f16")
    model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="q8_0")
    return trainer_stats


def train_lora_finetuning(
    model_name="unsloth/gemma-3-270m-it",
    train_data="qmaru/gemma3-sms",
    eval_data="qmaru/gemma3-sms",
    max_seq_length=2048,
    r=128,
    lora_alpha=64,
    lora_dropout=0.05,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=3,
    output_dir="outputs_lora",
    model_output_dir="model_lora",
    logging_steps=5,
    seed=3407,
):
    model, tokenizer = FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Attach LoRA adapters to all attention and MLP projections.
    model = FastModel.get_peft_model(
        model,
        r=r,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=seed,
        use_rslora=False,
        loftq_config=None,
    )

    tokenizer = get_chat_template(tokenizer, chat_template="gemma3")
    train_dataset, eval_dataset = load_and_prepare_datasets(
        train_data, eval_data, tokenizer, seed
    )
    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        optim="adamw_8bit",
        warmup_ratio=warmup_ratio,
        logging_steps=logging_steps,
        seed=seed,
    )
    trainer_stats = trainer.train()

    # Merge the adapters into the base weights, then export GGUF variants.
    model.save_pretrained_merged(model_output_dir, tokenizer, save_method="merged_16bit")
    model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="f16")
    model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="q8_0")
    return trainer_stats


if __name__ == "__main__":
    train_data = "qmaru/gemma3-sms"
    eval_data = "qmaru/gemma3-sms"
    # train_full_finetuning(train_data=train_data, eval_data=eval_data)
    train_lora_finetuning(train_data=train_data, eval_data=eval_data)
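
# --- Optional: quick inference check ------------------------------------------
# A minimal sketch, not part of the pipeline above and never run automatically
# (it sits below the entry point; call it manually from a REPL). It assumes
# train_lora_finetuning() saved merged weights to "model_lora"; point model_dir
# at "model_full" for the full fine-tune output instead.
def sanity_check_inference(model_dir="model_lora", prompt="Hello!"):
    model, tokenizer = FastModel.from_pretrained(
        model_name=model_dir,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
    )
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # end with the "<start_of_turn>model" header
        return_tensors="pt",
    ).to(model.device)
    output = model.generate(input_ids=input_ids, max_new_tokens=128)
    print(tokenizer.decode(output[0], skip_special_tokens=True))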