Patching flash-attn
modeling_aero.py (+75 -1)
@@ -30,9 +30,16 @@ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.auto import AutoModel, AutoModelForCausalLM
 from transformers.utils import logging
+from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioFlashAttention2

 from .configuration_aero import AeroConfig

+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    print("flash_attn is not installed. Please install flash-attn to use flash attention for the audio tower")
+
 logger = logging.get_logger(__name__)


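Note that the guard above only prints when flash-attn is missing, so a model configured for flash_attention_2 would still fail later with a NameError on flash_attn_func. A stricter variant is sketched below; it is not part of this commit, _require_flash_attn is a hypothetical helper, and is_flash_attn_2_available is the real transformers utility:

# Sketch only; NOT part of this commit.
from transformers.utils import is_flash_attn_2_available

try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None

def _require_flash_attn() -> None:
    # Fail early with an actionable error instead of a deferred NameError.
    if flash_attn_func is None or not is_flash_attn_2_available():
        raise ImportError(
            "flash-attn is not installed; run `pip install flash-attn` to use "
            "flash_attention_2 for the audio tower"
        )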
@@ -78,6 +85,72 @@ class AeroCausalLMOutputWithPast(ModelOutput)
     audio_hidden_states: Optional[torch.FloatTensor] = None


+# The original flash-attn implementation in transformers for the Qwen2Audio encoder is buggy;
+# patch Qwen2AudioFlashAttention2.forward with this function.
+def qwen2_audio_flash_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states=None,
+    past_key_value=None,
+    attention_mask=None,
+    layer_head_mask=None,
+    output_attentions: bool = False,
+    cache_position=None,
+):
+    # Qwen2AudioFlashAttention2 does not support output_attentions
+    if output_attentions:
+        raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
+
+    bsz, tgt_len, _ = hidden_states.size()
+
+    # get query proj
+    query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+    # TODO: These transposes are quite inefficient, but flash attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
+    # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view operations.
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    causal_mask = attention_mask
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, : key_states.shape[-2]]
+
+    # In PEFT, the layer norms are usually cast to float32 for training stability,
+    # so the input hidden states may get silently cast to float32. Hence, we need to
+    # cast them back to the correct dtype just to be sure everything works as expected.
+    # This might slow down training & inference, so it is recommended not to cast the
+    # LayerNorms to fp32. (LlamaRMSNorm handles it correctly.)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(self.config, "_pre_quantization_dtype"):
+            target_dtype = self.config._pre_quantization_dtype
+        else:
+            target_dtype = self.q_proj.weight.dtype
+
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
+    dropout = self.dropout if self.training else 0.0
+    attn_output = flash_attn_func(
+        query_states, key_states, value_states, dropout, softmax_scale=None, causal=self.is_causal
+    )
+
+    attn_output = attn_output.reshape(bsz, tgt_len, -1)
+    attn_output = self.out_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, None
+
+
+

 class AeroAudioMultiModalProjector(nn.Module):
     def __init__(self, config: AeroConfig):
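For context on the call at the heart of this patch: flash_attn_func consumes q, k, v in the [batch_size, sequence_length, num_heads, head_dim] layout that the reshape/_shape/transpose steps above produce, requires fp16 or bf16 CUDA tensors (hence the dtype recast block), and returns output in the same layout. Note that the sliced causal_mask is never handed to flash_attn_func, so padding positions are not masked inside the audio tower under this patch. A minimal standalone sketch with illustrative shapes (requires a CUDA device and flash-attn installed):

# Minimal sketch of the flash_attn_func contract; shapes are illustrative.
import torch
from flash_attn import flash_attn_func

bsz, seqlen, num_heads, head_dim = 2, 128, 16, 64
q = torch.randn(bsz, seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# dropout_p is passed positionally, as in the patched forward; causal=False is
# what self.is_causal resolves to for the bidirectional audio encoder.
out = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=False)
assert out.shape == (bsz, seqlen, num_heads, head_dim)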
@@ -136,7 +209,8 @@ class AeroPreTrainedModel(PreTrainedModel):
 class AeroForConditionalGeneration(AeroPreTrainedModel, GenerationMixin):
     def __init__(self, config: AeroConfig):
         super().__init__(config)
-
+        if config._attn_implementation == "flash_attention_2":
+            Qwen2AudioFlashAttention2.forward = qwen2_audio_flash_attn_forward
         self.audio_tower_type = config.audio_config.model_type
         self.audio_tower = AutoModel.from_config(config.audio_config)
         self.audio_modal_projector = AeroAudioMultiModalProjector(config)
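Because the patch assigns to Qwen2AudioFlashAttention2.forward at class level, it takes effect for every instance of that attention module in the process, not just this model's audio tower. A hypothetical loading sketch (the checkpoint id is a placeholder, not from this commit):

# Hypothetical usage; the checkpoint id below is a placeholder.
import torch
from modeling_aero import AeroForConditionalGeneration

model = AeroForConditionalGeneration.from_pretrained(
    "your-org/aero-checkpoint",               # placeholder id
    torch_dtype=torch.bfloat16,               # flash-attn needs fp16/bf16
    attn_implementation="flash_attention_2",  # sets config._attn_implementation
)
# __init__ has now swapped in qwen2_audio_flash_attn_forward, so the audio
# tower's attention runs through flash_attn_func.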