gemma3-sms / train.py
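"""Fine-tune unsloth/gemma-3-270m-it on the qmaru/gemma3-sms dataset with Unsloth.

Two entry points: full-parameter fine-tuning and LoRA fine-tuning. Both save the
resulting model and export GGUF files (f16 and q8_0).
"""
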
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only
import torch
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer
from datasets import load_dataset
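
# Let cuDNN benchmark and pick the fastest kernels for fixed input shapes.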
torch.backends.cudnn.benchmark = True
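# Prefer bfloat16 on GPUs that support it; otherwise fall back to float16.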
dtype = (
torch.bfloat16
if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
else torch.float16
)
max_seq_length = 2048


def load_and_prepare_datasets(train_data, eval_data, tokenizer, seed=3407):
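    """Load train/eval data (local JSON/JSONL files or Hub dataset IDs), shuffle, and add a chat-formatted "text" column."""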
if train_data.endswith(".jsonl") or train_data.endswith(".json"):
train_dataset = load_dataset("json", data_files=train_data, split="train").shuffle(
seed=seed
)
else:
train_dataset = load_dataset(train_data, split="train").shuffle(seed=seed)
if eval_data.endswith(".jsonl") or eval_data.endswith(".json"):
eval_dataset = load_dataset("json", data_files=eval_data, split="train").shuffle(seed=seed)
else:
eval_dataset = load_dataset(eval_data, split="test").shuffle(seed=seed)
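
    # Render each conversation with the chat template; strip the leading <bos>
    # so it is not duplicated when the trainer tokenizes the text.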
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [
tokenizer.apply_chat_template(
convo,
tokenize=False,
add_generation_prompt=False,
).removeprefix("<bos>")
for convo in convos
]
return {"text": texts}
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
eval_dataset = eval_dataset.map(formatting_prompts_func, batched=True)
    return train_dataset, eval_dataset


def create_trainer(
model,
tokenizer,
train_dataset,
eval_dataset,
output_dir,
per_device_train_batch_size,
gradient_accumulation_steps,
learning_rate,
num_train_epochs,
optim,
warmup_ratio=0.03,
logging_steps=5,
seed=3407,
):
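    """Build an SFTTrainer with a cosine schedule, per-epoch eval/save, and loss computed only on model responses."""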
sft_config = SFTConfig(
dataset_text_field="text",
dataset_num_proc=4,
packing=False,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
warmup_ratio=warmup_ratio,
num_train_epochs=num_train_epochs,
optim=optim,
lr_scheduler_type="cosine",
weight_decay=0.01,
output_dir=output_dir,
logging_steps=logging_steps,
save_strategy="epoch",
eval_strategy="epoch",
load_best_model_at_end=True,
report_to="none",
seed=seed,
bf16=(dtype == torch.bfloat16),
fp16=(dtype == torch.float16),
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=sft_config,
)
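    # Mask the prompt tokens so the loss covers only the model's responses
    # (delimited by the Gemma-3 turn markers).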
trainer = train_on_responses_only(
trainer,
instruction_part="<start_of_turn>user\n",
response_part="<start_of_turn>model\n",
)
    return trainer


def train_full_finetuning(
model_name="unsloth/gemma-3-270m-it",
train_data="qmaru/gemma3-sms",
eval_data="qmaru/gemma3-sms",
max_seq_length=2048,
per_device_train_batch_size=32,
gradient_accumulation_steps=1,
learning_rate=2e-5,
warmup_ratio=0.03,
num_train_epochs=3,
output_dir="outputs_full_finetune",
model_output_dir="model_full",
logging_steps=5,
seed=3407,
):
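    """Full-parameter fine-tune; saves the model and tokenizer and exports f16 and q8_0 GGUF files."""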
model, tokenizer = FastModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=True,
)
if torch.cuda.is_available():
torch.cuda.empty_cache()
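    # Attach the Gemma-3 chat template to the tokenizer.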
tokenizer = get_chat_template(tokenizer, chat_template="gemma3")
train_dataset, eval_dataset = load_and_prepare_datasets(train_data, eval_data, tokenizer, seed)
trainer = create_trainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir=output_dir,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
optim="adamw_torch",
warmup_ratio=warmup_ratio,
logging_steps=logging_steps,
seed=seed,
)
trainer_stats = trainer.train()
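    # Save the fine-tuned weights and tokenizer, then export GGUF files for llama.cpp.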
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)
model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="f16")
model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="q8_0")
    return trainer_stats


def train_lora_finetuning(
model_name="unsloth/gemma-3-270m-it",
train_data="qmaru/gemma3-sms",
eval_data="qmaru/gemma3-sms",
max_seq_length=2048,
r=128,
lora_alpha=64,
lora_dropout=0.05,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
learning_rate=2e-4,
warmup_ratio=0.03,
num_train_epochs=3,
output_dir="outputs_lora",
model_output_dir="model_lora",
logging_steps=5,
seed=3407,
):
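    """LoRA fine-tune; merges the adapters into 16-bit weights and exports f16 and q8_0 GGUF files."""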
model, tokenizer = FastModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
)
if torch.cuda.is_available():
torch.cuda.empty_cache()
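    # Wrap the model with LoRA adapters on the attention and MLP projection layers.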
model = FastModel.get_peft_model(
model,
r=r,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=seed,
use_rslora=False,
loftq_config=None,
)
tokenizer = get_chat_template(tokenizer, chat_template="gemma3")
train_dataset, eval_dataset = load_and_prepare_datasets(train_data, eval_data, tokenizer, seed)
trainer = create_trainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir=output_dir,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
optim="adamw_8bit",
warmup_ratio=warmup_ratio,
logging_steps=logging_steps,
seed=seed,
)
trainer_stats = trainer.train()
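    # Merge the LoRA adapters into the base weights, save 16-bit weights, then export GGUF.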
model.save_pretrained_merged(model_output_dir, tokenizer, save_method="merged_16bit")
model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="f16")
model.save_pretrained_gguf(model_output_dir, tokenizer, quantization_method="q8_0")
    return trainer_stats


if __name__ == "__main__":
train_data = "qmaru/gemma3-sms"
eval_data = "qmaru/gemma3-sms"
# train_full_finetuning(train_data=train_data, eval_data=eval_data)
train_lora_finetuning(train_data=train_data, eval_data=eval_data)