from functools import partial
from typing import Any, Callable, Literal, Optional

import torch
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLDecoderLayer,
    Qwen2_5_VLFlashAttention2,
    rotate_half,
)

from .casa_attention import CASAAttention, CASAAttentionHandler
from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig


class QwenCASAAttentionHandler(CASAAttentionHandler):
    """Overrides CASAAttentionHandler with the right position embedding computation for Qwen."""

    def __init__(
        self,
        *args: Any,
        get_rope_index: Callable | None = None,
        grid_thw: torch.Tensor | None = None,
        position_ids_offset: int = 0,
        **kwargs: Any,
    ):
        assert get_rope_index is not None, "get_rope_index should be given for QwenCASA"
        self.get_rope_index = partial(get_rope_index, image_grid_thw=grid_thw)
        self.position_ids_offset = position_ids_offset
        super().__init__(*args, **kwargs)

    def compute_position_embeddings(
        self,
        rope_fn: Callable,
        sample_lengths: list[int],
        dummy_for_dtype_and_device: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the info required for position embeddings. Can be overridden, e.g. for Qwen."""
        # rope_fn is the rotary-embedding callable (returns cos/sin); the original model's
        # get_rope_index is wrapped in self.get_rope_index (see __init__).
        dummy_input_ids = torch.zeros(
            (int(sum(sample_lengths)),), device=dummy_for_dtype_and_device.device, dtype=torch.long
        )
        # Set image token ids
        dummy_input_ids[self.image_tokens_mask[:, 0]] = 151655

        # Required for the weird start-of-image tokens.
        # Highly recommended to use pre and post image tokens with Qwen.
        # Add vision start token ids (wherever a 151655 follows a 0).
        start_of_images = torch.logical_and(
            dummy_input_ids == 0,
            torch.nn.functional.pad(dummy_input_ids[1:] == 151655, (0, 1), value=0),
        )
        dummy_input_ids[start_of_images] = 151652

        # Rebatch dummy input ids
        padding_side = "left" if self.attention_mask is not None else "right"
        s = list(torch.split(dummy_input_ids, self.full_batch_lengths))
        mlen = max(_s.shape[0] for _s in s)
        trims = [mlen - _s.shape[0] for _s in s]
        dummy_input_ids = torch.stack(
            [
                torch.nn.functional.pad(
                    _s,
                    (
                        trims[i] if padding_side == "left" else 0,
                        trims[i] if padding_side == "right" else 0,
                    ),
                    value=-1,
                )
                for i, _s in enumerate(s)
            ],
            dim=0,
        )

        # We need to give an attention mask to get_rope_index when left padding
        attention_mask = torch.ones_like(dummy_input_ids)
        for i, t in enumerate(trims):
            if padding_side == "right":
                attention_mask[i, attention_mask.shape[-1] - t :] = 0
            else:
                attention_mask[i, :t] = 0

        # Compute position ids of shape (3, bs, seq)
        position_ids = (
            self.get_rope_index(dummy_input_ids, attention_mask=attention_mask)[0]
            + self.position_ids_offset
        )

        # Compute position embeddings and recover the flattened, unpadded shape
        cos, sin = rope_fn(dummy_for_dtype_and_device, position_ids)

        # Reflatten seq
        if padding_side == "right":
            cos = torch.cat(
                [cos[:, i : i + 1, : cos.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
            )
            sin = torch.cat(
                [sin[:, i : i + 1, : sin.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
            )
        else:
            cos = torch.cat([cos[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
            sin = torch.cat([sin[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
        return cos, sin

    def get_position_embedding(
        self,
        key: Literal["q", "kv"],
        num_queries: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        if self.position_embeds is None:
            return None
        cos, sin = self.position_embeds
        # For Q, we only want the text-only position embeddings
        if key == "q":
            cos, sin = (
                cos[:, :, ~self.image_tokens_mask[:, 0]],
                sin[:, :, ~self.image_tokens_mask[:, 0]],
            )
        elif key != "kv":
            raise ValueError(f"Unknown key for position embedding {key}")

        # Easy case (training or first step at inference): we use all the position embeddings
        if num_queries == 0:
            return cos, sin

        # If num_queries is given, we need to trim for *every sample in the batch*
        bls = self.full_batch_lengths if key == "kv" else self.batch_lengths
        cos = [x[:, :, -num_queries:] for x in torch.split(cos, bls, dim=2)]
        sin = [x[:, :, -num_queries:] for x in torch.split(sin, bls, dim=2)]
        return torch.cat(cos, dim=2), torch.cat(sin, dim=2)


class QwenCASAAttention(CASAAttention):
    """A CASA attention layer compatible with Qwen."""

    def __init__(
        self,
        config: Qwen2_5_VLCASAConfig,
        layer_idx: int | None,
        self_attn: torch.nn.Module | None = None,
        input_layernorm_fn: Callable | None = None,
    ):
        # Only overriding this init for typing purposes (the config type)
        super().__init__(config, layer_idx, self_attn, input_layernorm_fn)  # pyright: ignore[reportArgumentType]
        assert config.rope_scaling is not None
        self.mrope_section = config.rope_scaling["mrope_section"] * 2

    def apply_position_embeddings(
        self,
        key: Literal["q", "kv"],
        x: torch.Tensor,  # (batch, seq_len, num_heads, head_dim)
        casa_handler: CASAAttentionHandler | None,
        num_queries: int = 0,
        unsqueeze_dim: int = 1,
    ) -> torch.Tensor:  # (batch, seq_len, num_heads, head_dim)
        """Apply position embeddings to query and key states."""
        if casa_handler is not None:
            posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
            if posemb is not None:
                x = x.transpose(1, 2).to(torch.float32)
                cos, sin = posemb
                cos = torch.cat(
                    [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1
                ).unsqueeze(unsqueeze_dim)
                sin = torch.cat(
                    [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1
                ).unsqueeze(unsqueeze_dim)
                x = (x * cos) + (rotate_half(x) * sin)
                return x.transpose(1, 2)
        return x

    def init_from_config_proj(
        self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
    ) -> torch.nn.Linear:
        """Follows modeling_qwen2_5_vl.py initialization."""
        head_dim = config.hidden_size // config.num_attention_heads
        if key == "q":
            return torch.nn.Linear(
                config.hidden_size, config.num_attention_heads * head_dim, bias=True
            )
        if key in {"k", "v"}:
            return torch.nn.Linear(
                config.hidden_size, config.num_key_value_heads * head_dim, bias=True
            )
        if key == "o":
            return torch.nn.Linear(
                config.num_attention_heads * head_dim, config.hidden_size, bias=False
            )
        raise NotImplementedError(f"Unknown key {key}")


class Qwen2_5_VLAttention_CASA(Qwen2_5_VLFlashAttention2):
    """
    Qwen attention with an extra CASA attention layer
    """

    def __init__(
        self,
        config: Qwen2_5_VLCASAConfig,
        layer_idx: Optional[int] = None,
        input_layernorm: torch.nn.Module | None = None,
    ):
        super().__init__(config, layer_idx)  # pyright: ignore[reportArgumentType]
        self.casa_attn = QwenCASAAttention(
            config,
            layer_idx=layer_idx,
            self_attn=self,
            input_layernorm_fn=input_layernorm.forward if input_layernorm is not None else None,
        )
        self.casa_attention_handler: CASAAttentionHandler | None = None

    @classmethod
    def from_qwen2_5_vl_attention(
        cls, attention: Qwen2_5_VLFlashAttention2, input_layernorm: torch.nn.Module | None
    ):
        """Init this layer from a Qwen attention layer."""
        layer_idx = attention.layer_idx
        assert layer_idx is not None
        new_attention = cls(attention.config, layer_idx=layer_idx, input_layernorm=input_layernorm)  # pyright: ignore
        new_attention.load_state_dict(attention.state_dict(), strict=False)
        return new_attention

    def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        casa_out: None | torch.Tensor = None
        if self.casa_attn is not None and self.config.xa_order in {
            "parallel",
            "ca_first",
            "instead",
        }:
            casa_out = self.casa_attn(
                hidden_states=hidden_states,
                casa_handler=self.casa_attention_handler,
            )
        if self.config.xa_order == "instead":
            return casa_out, None, None
        if self.config.xa_order == "ca_first" and casa_out is not None:
            hidden_states, casa_out = casa_out, None

        attn_output, attn_weights, past_key_values = super().forward(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            cache_position,
            position_embeddings,
        )

        if self.config.xa_order == "parallel" and casa_out is not None:
            attn_output = casa_out + attn_output

        return attn_output, attn_weights, past_key_values


def add_casa_layers(m: torch.nn.Module, xa_layers: tuple[int, ...] | None):
    """Replace the self-attention layer with a CASA attention layer where needed."""
    if isinstance(m, Qwen2_5_VLDecoderLayer):
        layer_idx = m.self_attn.layer_idx
        assert layer_idx is not None
        if xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers:
            m.self_attn = Qwen2_5_VLAttention_CASA.from_qwen2_5_vl_attention(
                m.self_attn, input_layernorm=m.input_layernorm
            )
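

# Usage sketch (illustrative only): `add_casa_layers` is designed to be mapped over a
# model with `torch.nn.Module.apply`. The checkpoint name and layer indices below are
# placeholders, and loading with attn_implementation="flash_attention_2" is an
# assumption based on the Qwen2_5_VLFlashAttention2 base class, not a requirement
# stated in this module.
#
#     from functools import partial
#     from transformers import Qwen2_5_VLForConditionalGeneration
#
#     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#         "Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder checkpoint
#         attn_implementation="flash_attention_2",
#     )
#     # Swap self-attention for CASA attention in decoder layers 8 and 16 (arbitrary
#     # example indices); pass None or () as xa_layers to convert every decoder layer.
#     model.apply(partial(add_casa_layers, xa_layers=(8, 16)))
#
# Each converted layer exposes a `casa_attention_handler` attribute (None by default);
# attaching a QwenCASAAttentionHandler to it is handled outside this file.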