Add assistant mask support to Qwen3-8B

#14
by waleko - opened

Enable Assistant Token Masking for Qwen3-8B

This pull request introduces support for assistant token masking in Qwen models by incorporating the {% generation %} tag within the chat template.

HuggingFace Transformers can return a mask of the tokens generated by the assistant via the return_assistant_tokens_mask argument of tokenizer.apply_chat_template (see huggingface/transformers#30650). Unfortunately, many LLMs still don't support this feature, even though it has been a year since it was added.

🛠️ Proposed Chat Template Change
--- tokenizer_config.json (original)
+++ tokenizer_config.json (modified)
@@ -40,14 +40,17 @@
                 {%- set content = content.split('</think>')[-1].lstrip('\n') %}
             {%- endif %}
         {%- endif %}
+
+        {{- '<|im_start|>' + message.role }}
+        {% generation %}
         {%- if loop.index0 > ns.last_query_index %}
             {%- if loop.last or (not loop.last and reasoning_content) %}
-                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+                {{- '<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
             {%- else %}
-                {{- '<|im_start|>' + message.role + '\n' + content }}
+                {{- content }}
             {%- endif %}
         {%- else %}
-            {{- '<|im_start|>' + message.role + '\n' + content }}
+            {{- content }}
         {%- endif %}
         {%- if message.tool_calls %}
             {%- for tool_call in message.tool_calls %}
@@ -68,7 +71,8 @@
                 {{- '}\n</tool_call>' }}
             {%- endfor %}
         {%- endif %}
-        {{- '<|im_end|>\n' }}
+        {{- '<|im_end|>' }}
+        {% endgeneration %}
     {%- elif message.role == "tool" %}
         {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
             {{- '<|im_start|>user' }}

Why This is Important

Distinguishing between tokens generated by the assistant and those originating from the user or environment is critical for various advanced applications. A prime example is multi-turn Reinforcement Learning (RL) training.

Currently, in frameworks like VeRL, identifying actor-generated tokens often requires manual reconstruction from the model's output. With this change to the chat template, that process is significantly simplified by leveraging an existing solution rather than reinventing the wheel.

It would be great if Qwen models supported this feature, as they are widely used in the RL community.
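
To make the RL use case concrete, here is a minimal sketch of consuming the mask as a per-token action mask for policy updates (the advantages tensor is a placeholder, and this is not VeRL's actual API):

import torch
from transformers import AutoTokenizer

# Load the tokenizer from this PR's revision (illustrative)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", revision="refs/pr/14")

out = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "How are you?"},
        {"role": "assistant", "content": "I'm good"},
    ],
    return_assistant_tokens_mask=True,
    return_dict=True,
)

# Only assistant tokens are "actions" the policy should be updated on
action_mask = torch.tensor(out["assistant_masks"], dtype=torch.float)
advantages = torch.randn(action_mask.shape[0])  # placeholder per-token advantages
masked_advantages = advantages * action_mask    # zero out user/environment tokens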

🚀 Usage Example

The following demonstrates how to retrieve the assistant token mask:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

conversation = [
    {"role": "user", "content": "Hello assistant"},
    {"role": "assistant", "content": "Hello user"},
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "I'm good"},
]

tokenized_output = tokenizer.apply_chat_template(
    conversation,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("Tokenized Output with Assistant Mask:")
print(tokenized_output)

# BEFORE
# {'input_ids': [151644, 872, 198, 9707, 17847, 151645, 198, 151644, 77091, 198, 9707, 1196, 151645, 198, 151644, 872, 198, 4340, 525, 498, 30, 151645, 198, 151644, 77091, 198, 151667, 271, 151668, 271, 40, 2776, 1661, 151645, 198], 
#  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#  'assistant_masks': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# }

# AFTER
# {'input_ids': [151644, 872, 198, 9707, 17847, 151645, 198, 151644, 77091, 198, 9707, 1196, 151645, 198, 151644, 872, 198, 4340, 525, 498, 30, 151645, 198, 151644, 77091, 198, 151667, 271, 151668, 271, 40, 2776, 1661, 151645, 198], 
#  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#  'assistant_masks': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# }

Visualizing the mask makes it easy to see which parts of the input correspond to the assistant's generation:

[visualization of the assistant token mask]

Testing

  • Verified the template works in both tool and non-tool scenarios
  • Verified it works with reasoning content (a minimal version of the check is sketched below)
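
For anyone who wants to reproduce the verification, one possible check (the conversation is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", revision="refs/pr/14")

out = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "How are you?"},
        {"role": "assistant", "content": "I'm good"},
    ],
    return_assistant_tokens_mask=True,
    return_dict=True,
)

# Decode only the masked positions: the result should read as the assistant
# turn (think block, content, and end-of-turn marker) and nothing else
masked_ids = [tok for tok, m in zip(out["input_ids"], out["assistant_masks"]) if m]
print(tokenizer.decode(masked_ids))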

@waleko could you merge this? thanks!
ah I figured you're not affiliated with hf, sorry!

hope this will be approved asap

In case you want to visualize the mask, here is a snippet:

from transformers import AutoTokenizer
from rich import print
from rich.text import Text

# Load the tokenizer from this PR's revision, which already includes the modified template
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", revision="refs/pr/14")

conversation = [
    {"role": "system", "content": "You are a friendly assistant"},
    {"role": "user", "content": "Hello assistant"},
    {"role": "assistant", "content": "Hello user"},
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "Fine thanks, and you?"},
    {"role": "user", "content": "Great, thanks"},
]
tokenized_output = tokenizer.apply_chat_template(
    conversation,
    return_assistant_tokens_mask=True,
    return_dict=True,
)
print("Tokenized Output with Assistant Mask:")
print(tokenized_output)

# Visualize using rich: assistant tokens in cyan, everything else in white
text_visualization = Text()
tokens = tokenizer.convert_ids_to_tokens(tokenized_output.input_ids)
for token, mask in zip(tokens, tokenized_output.assistant_masks):
    color = "cyan" if mask else "white"
    text_visualization.append(token.replace("Ġ", " ").replace("Ċ", "\n"), style=color)

print(text_visualization)

Thanks @uralik and @edbeeching !
@edbeeching , really appreciate the visualization snippet. I’ve also attached execution results so others can see how it behaves in practice.

I’m not affiliated with the Qwen team, so this PR has unfortunately been stale for a while. I believe this change is quite important, and I’d love to see it land in Qwen. Any help in getting it noticed by the maintainers would be very appreciated!
[screenshot: execution results of the visualization snippet]

Thank you for your thoughtful contribution and detailed explanation of the benefits of assistant token masking for use cases like multi-turn RL training.

After careful consideration, we are unable to merge this change into the default chat template at this time. Our main concerns are:

  1. It's important to clarify that generated assistant tokens (as defined by return_assistant_tokens_mask) and masked tokens during training are not necessarily the same concept.

    The choice of what to mask (i.e., which parts of the input should contribute to the loss computation or policy update) is often a delicate, context-dependent matter, determined by the specific goals of the training pipeline; it evolves with the training objective (e.g., keeping user tokens in RFT), the training stage, and the data characteristics (e.g., masking the thinking block within assistant tokens, or keeping system tokens in SFT).

    Including a specific masking behavior via the default chat template could lead to misalignment with downstream use cases and potentially mislead users about what exactly is being masked and why. Instead of enforcing a one-size-fits-all approach through the tokenizer, we believe it is more appropriate for such masking logic to remain flexible and controlled at the data processing or training framework level, where the full context and intent are known.

  2. The proposed change relies heavily on a specific HuggingFace Transformers feature (a custom Jinja template extension) unknown to standard Jinja template engines. Adopting this into the official Qwen tokenizer chat template could introduce inconsistencies or usability issues for users working outside of HuggingFace's ecosystem, e.g., GGUF-based frameworks.

As an alternative, we suggest providing your modified chat template as a separate file (e.g., chat_template_with_assistant_mask.jinja) and documenting its purpose and usage in the model card, including:

  • A description of what assistant token masking is and how to use it with return_assistant_tokens_mask=True
  • Supported frameworks and required versions in the ecosystem
  • Any limitations or caveats related to training vs. inference use

This approach allows interested users to opt-in to this specialized behavior while avoiding unintended side effects for others who may rely on the default template. It also keeps the core tokenizer behavior consistent and broadly compatible across different tools and platforms.
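
For example, opting in could look like this (the filename follows the suggestion above and is hypothetical):

from pathlib import Path
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Override the default template with the opt-in masking variant
tokenizer.chat_template = Path("chat_template_with_assistant_mask.jinja").read_text()

out = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello assistant"},
        {"role": "assistant", "content": "Hello user"},
    ],
    return_assistant_tokens_mask=True,
    return_dict=True,
)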

Thank you again for your understanding and contribution to the community.

Hi @jklj077 , thank you for the detailed review and explanations! Both of your points make a lot of sense; I completely understand the focus on consistency and on avoiding breaking changes. Supporting this as an opt-in sounds like a great middle ground. If you decide to move forward with it and need anything from my side, just let me know, happy to help!

Thanks again for your time and all your work on Qwen!

Hi @waleko and @edbeeching , cheers for writing the modified template and visualisation code.

I've written another version, masking everything but the final assistant response -- useful if you want to create a contrived chat history and only train on the final assistant response.

I've uploaded both templates + visualisation to a github repo --> https://github.com/HarryMayne/qwen_3_chat_templates
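
The same effect can also be had by post-processing the standard mask in Python, keeping only the last contiguous run of 1s (a sketch, independent of the templates in the repo):

def keep_final_assistant_span(assistant_masks):
    """Zero out every assistant span except the final contiguous run of 1s."""
    mask = list(assistant_masks)
    ones = [i for i, m in enumerate(mask) if m]
    if not ones:
        return mask
    end = ones[-1]
    start = end
    # Walk back to the start of the final span
    while start > 0 and mask[start - 1]:
        start -= 1
    return [1 if start <= i <= end else 0 for i in range(len(mask))]

# keep_final_assistant_span([0, 1, 1, 0, 0, 1, 1, 1]) -> [0, 0, 0, 0, 0, 1, 1, 1]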

The assistant mask support request makes a lot of sense for Qwen3-8B specifically given how the model handles thinking tokens. The core issue is that when you're doing SFT or RLHF on top of this model, you need to properly mask out the <|im_start|>assistant prefix tokens and, critically, decide what to do with the <think>...</think> block — whether you want to include those reasoning tokens in the loss computation or mask them out entirely. Most training frameworks don't handle this cleanly out of the box for models with interleaved reasoning traces.

For practical implementation, the cleanest approach is extending the chat template to emit a loss_mask field alongside input_ids, where you explicitly zero out the loss on system/user turns and optionally on the thinking segments. If you're using trl or a custom trainer, you can patch the DataCollatorForSeq2Seq to respect this mask. The tricky edge case with Qwen3-8B is the non-thinking mode (/no_think or enable_thinking=False) — your masking logic needs to handle both branches of the template, otherwise you'll get inconsistent training behavior depending on whether thinking was enabled during data generation.
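
A rough sketch of that masking logic at the data-processing level, building on the assistant mask from this PR (build_labels and the think-span handling are illustrative, not an existing trl API):

import torch

def build_labels(input_ids, assistant_masks, tokenizer, mask_thinking=True):
    """Labels are -100 (ignored by the loss) everywhere except assistant tokens."""
    labels = [tok if m else -100 for tok, m in zip(input_ids, assistant_masks)]
    if mask_thinking:
        # Optionally drop the <think>...</think> span from the loss as well
        think_id = tokenizer.convert_tokens_to_ids("<think>")
        end_think_id = tokenizer.convert_tokens_to_ids("</think>")
        inside = False
        for i, tok in enumerate(input_ids):
            if tok == think_id:
                inside = True
            if inside:
                labels[i] = -100
            if tok == end_think_id:
                inside = False
    return torch.tensor(labels)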

One thing worth flagging if you're building multi-agent pipelines on top of this: when Qwen3-8B is acting as a subagent and you're logging or verifying which model generated which response segment, the assistant mask boundaries become load-bearing for attribution. We ran into this in AgentGraph when trying to verify that reasoning traces actually came from the model rather than being injected upstream — the mask tells you where the model's "voice" begins, which matters for trust scoring in orchestrated pipelines. Getting the mask right at the data layer pays dividends later.

The assistant mask question here is pretty relevant to how Qwen3-8B gets used in multi-turn agentic pipelines. The core issue is that without proper assistant masking during fine-tuning, the model computes loss over tokens it shouldn't — specifically the assistant prefix and any injected system context — which degrades the quality of instruction-following and can cause the model to "bleed" formatting artifacts into generation. For Qwen3-8B specifically, the chat template already encodes turn boundaries via <|im_start|> and <|im_end|> tokens, so the masking logic needs to correctly identify which spans correspond to assistant turns and zero out the loss there during SFT.

The practical fix is making sure your data collator respects the role boundaries when constructing labels. A lot of people use trl's DataCollatorForCompletionOnlyLM with a response template string, but with Qwen3's tokenization you sometimes get off-by-one errors if the response template token sequence doesn't align cleanly after tokenization. Worth logging the label tensors directly and verifying that -100 masking covers exactly the human/system turns and nothing bleeds into the assistant response spans.
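
To log and eyeball the label tensors as suggested, a tiny sketch (input_ids, labels, and tokenizer are assumed to come from your own pipeline, with -100 on ignored label positions):

# Decode only the supervised positions; the output should read as the
# assistant responses (plus end-of-turn markers), with no user/system
# text bleeding in
supervised = [tok for tok, lab in zip(input_ids, labels.tolist()) if lab != -100]
print(tokenizer.decode(supervised))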

One thing worth flagging for anyone running Qwen3-8B in multi-agent setups: if agents are calling each other and injecting intermediate outputs into the conversation history, the masking problem compounds — you end up with assistant-generated content appearing in positions the model treats as "user" context, which creates training distribution mismatch if you're doing any online or continual fine-tuning. This is actually an area where explicit agent identity tracking matters; at AgentGraph we've run into exactly this when tracing which message segments were produced by which agent role, and having that provenance makes it much easier to construct correct label masks programmatically rather than relying on heuristic string matching against role prefixes.

