# Distillation Trainer

## Overview

The Distillation Trainer implements on-policy knowledge distillation as described in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution.

The `DistillationTrainer` is designed for distilling teacher models of all sizes into smaller students efficiently. It extends the ideas from the `GKDTrainer` with three key optimizations:

1. **Generation buffer** – decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts in a single call across gradient accumulation steps. This alone can speed up training by up to 40x.
2. **Teacher server support** – moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.
3. **Binary-encoded logprob payloads** – packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x.
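
To make the payload optimization concrete, here is a minimal sketch that compares a JSON list of log-probabilities with a base64-encoded binary buffer. It is an illustration under assumptions, not the trainer's wire format: it substitutes the standard-library `struct` module for NumPy to stay dependency-free, and the little-endian float32 layout and the `encode_logprobs` / `decode_logprobs` helpers are hypothetical.

```python
import base64
import json
import struct


def encode_logprobs(logprobs: list[float]) -> str:
    """Pack logprobs as little-endian float32 (assumed layout) and base64-encode the bytes."""
    raw = struct.pack(f"<{len(logprobs)}f", *logprobs)
    return base64.b64encode(raw).decode("ascii")


def decode_logprobs(payload: str) -> list[float]:
    """Inverse of encode_logprobs: base64-decode, then unpack 4-byte floats."""
    raw = base64.b64decode(payload)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


# Hypothetical per-token logprobs for one completion
logprobs = [-0.123456789 * (i % 17) - 1.0 / (i + 1) for i in range(2048)]

as_json = json.dumps(logprobs)
as_b64 = encode_logprobs(logprobs)
print(f"JSON: {len(as_json)} chars, base64 float32: {len(as_b64)} chars")

# float32 keeps roughly 7 significant digits, which is ample for logprobs
restored = decode_logprobs(as_b64)
assert all(abs(a - b) < 1e-5 for a, b in zip(logprobs, restored))
```

Each value costs about 5.3 base64 characters instead of the roughly 20 characters of a full-precision decimal in JSON, which is where most of the size reduction comes from.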

> [!NOTE]
> The Distillation Trainer is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

## Quick start

```python
from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,                      # fully on-policy (student generates)
    beta=1.0,                       # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model()
```

## Usage tips

The [experimental.distillation.DistillationTrainer](/docs/trl/v1.3.0/en/distillation_trainer#trl.experimental.distillation.DistillationTrainer) needs three key parameters set via [experimental.distillation.DistillationConfig](/docs/trl/v1.3.0/en/distillation_trainer#trl.experimental.distillation.DistillationConfig):

* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, training is fully off-policy (dataset completions only). When `lmbda=1.0`, training is fully on-policy (student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on `lmbda`.
* `beta`: controls the interpolation in the Generalized Jensen-Shannon Divergence. When `beta=0.0`, the loss approximates forward KL divergence, while `beta=1.0` approximates reverse KL divergence. Values in between interpolate between the two.
* `loss_top_k`: number of top tokens to use for the KL/JSD loss. Set to `0` for exact full-vocabulary computation (local teacher only), or `> 0` for a top-k approximation. An external teacher server imposes additional constraints on `loss_top_k`; see the external teacher server section below.
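
The role of `beta` can be sketched on a single token position with full-vocabulary distributions. This pure-Python version follows the generalized JSD formulation from the GKD paper, with `beta=0` and `beta=1` special-cased to forward KL(teacher || student) and reverse KL(student || teacher); whether `DistillationTrainer` uses exactly this weighting, and how its top-k approximation truncates the vocabulary, are assumptions not covered here. The `kl` and `generalized_jsd` helpers are illustrative names.

```python
import math


def kl(p: list[float], q: list[float]) -> float:
    """KL(p || q) for discrete distributions given as probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student: list[float], teacher: list[float], beta: float) -> float:
    """Generalized JSD between one token's student and teacher distributions.

    beta=0 -> forward KL(teacher || student); beta=1 -> reverse KL(student || teacher);
    otherwise both are measured against the mixture M = beta*teacher + (1-beta)*student.
    """
    if beta == 0.0:
        return kl(teacher, student)
    if beta == 1.0:
        return kl(student, teacher)
    mixture = [beta * t + (1 - beta) * s for s, t in zip(student, teacher)]
    return beta * kl(teacher, mixture) + (1 - beta) * kl(student, mixture)


# Toy 3-token vocabulary
student = [0.7, 0.2, 0.1]
teacher = [0.5, 0.3, 0.2]
for beta in (0.0, 0.5, 1.0):
    print(beta, round(generalized_jsd(student, teacher, beta), 4))
```

Forward KL penalizes the student for missing tokens the teacher considers likely (mode covering), while reverse KL penalizes the student for placing mass where the teacher does not (mode seeking).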

### On-policy vs. off-policy

Setting `lmbda=1.0` (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer ensures on-policy training stays efficient: prompts across gradient accumulation steps are batched into a single vLLM call.
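
The interaction between `lmbda` and the generation buffer can be sketched as follows: decide on- or off-policy for every gradient accumulation slice up front, send the prompts of all on-policy slices to the generator in one call, then split the completions back per slice. This is a hypothetical illustration; `fill_generation_buffer` and `fake_generate` are not TRL APIs, and the real buffer may differ (for example, in how random draws are synchronized across processes).

```python
import random
from typing import Callable, List, Optional, Tuple


def fill_generation_buffer(
    microbatches: List[List[str]],
    generate_fn: Callable[[List[str]], List[str]],
    lmbda: float,
    rng: random.Random,
) -> List[Tuple[bool, Optional[List[str]]]]:
    """Return one (on_policy, completions) pair per gradient accumulation slice.

    completions is None for off-policy slices, which keep their dataset completions.
    """
    # Each slice is on-policy with probability lmbda
    on_policy = [rng.random() < lmbda for _ in microbatches]
    # Flatten the prompts of every on-policy slice into a single generation batch
    flat_prompts = [p for mb, on in zip(microbatches, on_policy) if on for p in mb]
    flat_completions = generate_fn(flat_prompts) if flat_prompts else []
    # Split completions back into per-slice chunks, preserving order
    results, cursor = [], 0
    for mb, on in zip(microbatches, on_policy):
        if on:
            results.append((True, flat_completions[cursor : cursor + len(mb)]))
            cursor += len(mb)
        else:
            results.append((False, None))
    return results


calls = []


def fake_generate(prompts: List[str]) -> List[str]:
    """Hypothetical stand-in for a single vLLM generation call."""
    calls.append(len(prompts))
    return [f"completion for {p}" for p in prompts]


# 8 gradient accumulation slices x 4 prompts each
microbatches = [[f"q{s}-{i}" for i in range(4)] for s in range(8)]
buffer = fill_generation_buffer(microbatches, fake_generate, lmbda=1.0, rng=random.Random(0))
print(calls)  # [32]: one call covering all 32 prompts
```

Without the buffer, the same step would issue eight separate calls of four prompts each, leaving most of vLLM's batching capacity unused.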

### Using an external teacher server

For teachers that do not fit on training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set `use_teacher_server=True` with `teacher_model_server_url`:

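```python
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# `dataset` is a prompt-only conversational dataset, such as the GSM8K one built in the quick start
config = DistillationConfig(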
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    loss_top_k=1,       # required with teacher server when beta > 0
    beta=1.0,
    lmbda=1.0,
)

trainer = DistillationTrainer(
    model="Qwen/Qwen3-4B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```

When using the teacher server:
- `loss_top_k` must be `> 0` when `beta=0.0` (forward KL)
- `loss_top_k` must be exactly `1` when `beta > 0` (reverse KL or JSD)
- `reverse_kl_top_1_mode="argmax"` is not supported
- Liger kernel is not supported
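
The four rules above can be restated as a pre-flight check. This is a hypothetical helper, not part of TRL (the trainer performs its own validation); it assumes that Liger is toggled through the `use_liger_kernel` training argument and treats `reverse_kl_top_1_mode` as an optional string.

```python
from typing import Optional


def validate_teacher_server_settings(
    beta: float,
    loss_top_k: int,
    reverse_kl_top_1_mode: Optional[str] = None,
    use_liger_kernel: bool = False,
) -> None:
    """Raise ValueError if a teacher-server config breaks one of the documented rules."""
    if beta == 0.0 and loss_top_k <= 0:
        raise ValueError("loss_top_k must be > 0 with a teacher server when beta=0.0 (forward KL)")
    if beta > 0.0 and loss_top_k != 1:
        raise ValueError("loss_top_k must be exactly 1 with a teacher server when beta > 0")
    if reverse_kl_top_1_mode == "argmax":
        raise ValueError('reverse_kl_top_1_mode="argmax" is not supported with a teacher server')
    if use_liger_kernel:
        raise ValueError("Liger kernel is not supported with a teacher server")


validate_teacher_server_settings(beta=1.0, loss_top_k=1)   # reverse KL: OK
validate_teacher_server_settings(beta=0.0, loss_top_k=20)  # forward KL with top-20: OK
```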

### Expected dataset type

The dataset should be formatted as a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```

When using fully on-policy distillation (`lmbda=1.0`), the assistant turn can be omitted since the student will generate its own completions:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"}]}
```
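
A small helper for producing either format from raw prompt/answer pairs might look like this. Both `to_messages` and `needs_completions` are hypothetical names. The second helper encodes the rule implied above: any `lmbda` below `1.0` leaves some slices off-policy, and those train on dataset completions.

```python
from typing import Optional


def to_messages(prompt: str, completion: Optional[str] = None) -> dict:
    """Build one conversational example; omit the assistant turn for prompt-only data."""
    messages = [{"role": "user", "content": prompt}]
    if completion is not None:
        messages.append({"role": "assistant", "content": completion})
    return {"messages": messages}


def needs_completions(lmbda: float) -> bool:
    """Off-policy slices use dataset completions, so they are required unless lmbda == 1.0."""
    return lmbda < 1.0


print(to_messages("What color is the sky?", "It is blue."))
print(to_messages("What color is the sky?"))  # prompt-only, valid when lmbda=1.0
```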

## Example script

Use [`trl/experimental/distillation/distillation.py`](https://github.com/huggingface/trl/blob/main/trl/experimental/distillation/distillation.py) to launch distillation training from the command line. The script supports full training, mixed on/off-policy, and LoRA via the standard `ModelConfig` flags.

```bash
# Full training (off-policy only, lmbda=0):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.0 \
    --output_dir distilled-model \
    --num_train_epochs 1
```

```bash
# Mixed on/off-policy (lmbda=0.5):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.5 \
    --beta 0.5 \
    --output_dir distilled-model \
    --num_train_epochs 1
```
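
The flags in these commands determine how many prompts each optimizer step consumes. The sketch below applies the standard `Trainer` effective-batch-size arithmetic (`per_device_train_batch_size × gradient_accumulation_steps × number of processes`). It also reports the expected number of on-policy slices under the assumption that each gradient accumulation slice is drawn independently with probability `lmbda`. `batch_budget` is a hypothetical helper.

```python
def batch_budget(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    lmbda: float,
    num_processes: int = 1,
) -> dict:
    """Prompt counts for one optimizer step under standard Trainer batching."""
    per_process = per_device_train_batch_size * gradient_accumulation_steps
    return {
        # prompts consumed by one optimizer step on each process
        "prompts_per_process": per_process,
        # summed over all data-parallel processes
        "prompts_global": per_process * num_processes,
        # expected on-policy slices if each slice is on-policy with probability lmbda
        "expected_on_policy_slices": lmbda * gradient_accumulation_steps,
    }


# Values from the mixed on/off-policy command above, on a single GPU
print(batch_budget(4, 8, lmbda=0.5))
```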

## DistillationTrainer[[trl.experimental.distillation.DistillationTrainer]]

#### trl.experimental.distillation.DistillationTrainer[[trl.experimental.distillation.DistillationTrainer]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/trl/experimental/distillation/distillation_trainer.py#L356)

Trainer for knowledge distillation from a teacher model to a student model.

Supports:
- Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via `beta`)
- On-policy / off-policy mixing via `lmbda` (buffered across gradient accumulation)
- Local teacher model or external teacher via vLLM server
- Student on-policy generation via vLLM or `model.generate()`
- Liger kernel for memory-efficient fused JSD loss

#### train[[trl.experimental.distillation.DistillationTrainer.train]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L1325)

`train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None)`

Main training entry point.

**Parameters:**

resume_from_checkpoint (`str` or `bool`, *optional*) : If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (`optuna.Trial` or `dict[str, Any]`, *optional*) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (`list[str]`, *optional*) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

**Returns:**

`~trainer_utils.TrainOutput`

Object containing the global step count, training loss, and metrics.

#### save_model[[trl.experimental.distillation.DistillationTrainer.save_model]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L3752)

Saves the model so it can be reloaded with `from_pretrained()`.

Only the main process writes to disk.

#### push_to_hub[[trl.experimental.distillation.DistillationTrainer.push_to_hub]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L3999)

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

**Parameters:**

commit_message (`str`, *optional*, defaults to `"End of training"`) : Message to commit while pushing.

blocking (`bool`, *optional*, defaults to `True`) : Whether the function should return only when the `git push` has finished.

token (`str`, *optional*, defaults to `None`) : Token with write permission. Overrides the token set in the Trainer's original args.

revision (`str`, *optional*) : The git revision to commit from. Defaults to the head of the "main" branch.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments passed along to `~Trainer.create_model_card`.

**Returns:**

The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.

## DistillationConfig[[trl.experimental.distillation.DistillationConfig]]

#### trl.experimental.distillation.DistillationConfig[[trl.experimental.distillation.DistillationConfig]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/trl/experimental/distillation/distillation_config.py#L23)

Configuration class for the `DistillationTrainer`.

Extends [TrainingArguments](https://huggingface.co/docs/transformers/v5.6.2/en/main_classes/trainer#transformers.TrainingArguments) with parameters specific to knowledge distillation. This config is
independent of [SFTConfig](/docs/trl/v1.3.0/en/sft_trainer#trl.SFTConfig) — all necessary fields are declared here.

Using [HfArgumentParser](https://huggingface.co/docs/transformers/v5.6.2/en/internal/trainer_utils#transformers.HfArgumentParser) we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
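
Conceptually, the parser exposes each dataclass field as a typed `--flag`. The sketch below mimics that idea with only `dataclasses` and `argparse` from the standard library. It is a simplified analogue, not how `HfArgumentParser` is implemented; for instance, it does not handle boolean flags. `MiniDistillationArgs` is a hypothetical subset of `DistillationConfig`, and its defaults are illustrative rather than the real ones. With the actual classes, the equivalent call is `HfArgumentParser(DistillationConfig).parse_args_into_dataclasses()`.

```python
import argparse
import dataclasses


@dataclasses.dataclass
class MiniDistillationArgs:
    """Hypothetical subset of DistillationConfig fields; defaults are illustrative only."""

    output_dir: str = "distilled-model"
    lmbda: float = 0.5
    beta: float = 0.5
    loss_top_k: int = 0


def dataclass_to_parser(cls) -> argparse.ArgumentParser:
    """Expose every dataclass field as a --flag typed by its annotation (bool fields not handled)."""
    parser = argparse.ArgumentParser()
    for field in dataclasses.fields(cls):
        parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return parser


parser = dataclass_to_parser(MiniDistillationArgs)
args = MiniDistillationArgs(**vars(parser.parse_args(["--lmbda", "1.0", "--beta", "0.0"])))
print(args)
```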

