mwirth7 committed on
Commit 6e0a042 · verified · 1 Parent(s): 4e8a89b

Model and Feature Extractor

Files changed (6)
  1. config.json +33 -0
  2. config.py +55 -0
  3. feature_extractor.py +132 -0
  4. model.py +452 -0
  5. model.safetensors +3 -0
  6. preprocessor_config.json +20 -0
config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "architectures": [
+     "BirdMAEModel"
+   ],
+   "attn_drop_rate": 0.0,
+   "auto_map": {
+     "AutoConfig": "config.BirdMAEConfig",
+     "AutoModel": "model.BirdMAEModel"
+   },
+   "depth": 12,
+   "drop_path_rate": 0.0,
+   "drop_rate": 0.0,
+   "embed_dim": 768,
+   "img_size_x": 512,
+   "img_size_y": 128,
+   "in_chans": 1,
+   "init_values": null,
+   "mlp_ratio": 4.0,
+   "norm_layer_eps": 1e-06,
+   "num_heads": 12,
+   "num_patches": 256,
+   "num_patches_x": 32,
+   "num_patches_y": 8,
+   "num_tokens": 257,
+   "patch_size": 16,
+   "pos_drop_rate": 0.0,
+   "pos_trainable": false,
+   "proj_drop_rate": 0.0,
+   "qk_norm": false,
+   "qkv_bias": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.0"
+ }
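Because config.json registers the custom classes through auto_map, the checkpoint is meant to be loaded with the Auto classes and trust_remote_code=True, which pulls config.py and model.py from the repository. A minimal sketch (the repo id below is a placeholder, not the actual repository name):

from transformers import AutoConfig, AutoModel

repo_id = "<user>/<bird-mae-repo>"  # placeholder repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(config.num_patches, config.num_tokens)  # 256 patches per spectrogram, 257 tokens including CLS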
config.py ADDED
@@ -0,0 +1,55 @@
+ from transformers import PretrainedConfig
+ import torch.nn as nn  # For norm_layer type
+
+
+ class BirdMAEConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         img_size_x=512,  # From provided config
+         img_size_y=128,  # From provided config
+         patch_size=16,  # From provided config
+         in_chans=1,  # From provided config
+         embed_dim=768,  # From provided config
+         depth=12,  # From provided config
+         num_heads=12,  # From provided config
+         mlp_ratio=4.0,  # From provided config
+         pos_trainable=False,  # From provided config
+         qkv_bias: bool = True,
+         qk_norm: bool = False,
+         init_values: float = None,
+         drop_rate=0.0,  # Not explicitly in the MAE_Encoder init, but Block has it
+         # attn_drop_rate=0.0,  # Not explicitly in the MAE_Encoder init, but Block has it
+         # drop_path_rate=0.0,  # Not explicitly in the MAE_Encoder init, but Block has it
+         norm_layer_eps=1e-6,  # Default for nn.LayerNorm
+         # cls_token=True,  # The MAE_Encoder uses self.cls_token
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.img_size_x = img_size_x
+         self.img_size_y = img_size_y
+         self.patch_size = patch_size
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.mlp_ratio = mlp_ratio
+         self.pos_trainable = pos_trainable
+
+         self.qkv_bias = qkv_bias
+         self.qk_norm = qk_norm
+         self.init_values = init_values
+         self.drop_rate = drop_rate
+         self.pos_drop_rate = drop_rate
+         self.attn_drop_rate = drop_rate
+         self.drop_path_rate = drop_rate
+         self.proj_drop_rate = drop_rate
+         self.norm_layer_eps = norm_layer_eps
+
+         # Calculated properties (useful for initializing the model)
+         self.num_patches_x = img_size_x // patch_size
+         self.num_patches_y = img_size_y // patch_size
+         self.num_patches = self.num_patches_x * self.num_patches_y
+         self.num_tokens = self.num_patches + 1
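config.py has no repository-internal imports, so the derived geometry can be sanity-checked by instantiating it directly (a sketch, assuming the file is importable as config):

from config import BirdMAEConfig

cfg = BirdMAEConfig()  # defaults mirror config.json
# a 512x128 spectrogram cut into 16x16 patches gives a 32x8 patch grid
assert (cfg.num_patches_x, cfg.num_patches_y) == (32, 8)
assert cfg.num_patches == 256 and cfg.num_tokens == 257  # 256 patches + 1 CLS token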
feature_extractor.py ADDED
@@ -0,0 +1,132 @@
+ from transformers import SequenceFeatureExtractor
+ from transformers.tokenization_utils_base import BatchEncoding
+ from transformers.feature_extraction_utils import BatchFeature
+ from torchaudio.compliance.kaldi import fbank
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+
+ from typing import Union, List
+ from transformers.utils import PaddingStrategy
+
+
+ class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
+     _auto_class = "AutoFeatureExtractor"
+     model_input_names = ["input_values"]
+
+     def __init__(
+         self,
+         # process waveform
+         feature_size: int = 1,
+         sampling_rate: int = 32_000,
+         padding_value: float = 0.0,
+         return_attention_mask: bool = True,
+         # fbank
+         htk_compat: bool = True,
+         use_energy: bool = False,
+         window_type: str = "hanning",
+         num_mel_bins: int = 128,
+         dither: float = 0.0,
+         frame_shift: int = 10,
+         # pad and normalize
+         target_length: int = 512,
+         mean: float = -7.2,
+         std: float = 4.43,
+         **kwargs
+     ):
+         super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
+         # sequence feature extractor
+         self.feature_size = feature_size
+         self.sampling_rate = sampling_rate
+         self.padding_value = padding_value
+         self.return_attention_mask = return_attention_mask
+
+         # fbank
+         self.htk_compat = htk_compat
+         self.use_energy = use_energy
+         self.window_type = window_type
+         self.num_mel_bins = num_mel_bins
+         self.dither = dither
+         self.frame_shift = frame_shift
+
+         # pad and normalize
+         self.target_length = target_length
+         self.mean = mean
+         self.std = std
+
+     def __call__(
+         self,
+         waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+         padding: Union[bool, str, PaddingStrategy] = "max_length",
+         max_length: int | None = None,
+         truncation: bool = True,
+         return_tensors: str = "pt"
+     ):
+         # wrap a single waveform into a batch of one
+         if isinstance(waveform_batch, (list, np.ndarray)) and not isinstance(waveform_batch[0], (list, np.ndarray)):
+             waveform_batch = [waveform_batch]
+
+         waveform_batch = self._process_waveforms(waveform_batch, padding, truncation)
+         fbank_features = self._compute_fbank_features(waveform_batch["input_values"])
+         fbank_features = self._pad_and_normalize(fbank_features)
+         return fbank_features
+
+     def _process_waveforms(self, waveforms, padding: bool | str, truncation: bool):
+         clip_duration = 5  # TODO this is the clip duration used in training
+         max_length = int(int(self.sampling_rate) * clip_duration)
+         waveform_encoded = BatchFeature({"input_values": waveforms})
+
+         waveform_batch = self.pad(
+             waveform_encoded,
+             padding=padding,
+             max_length=max_length,
+             truncation=truncation,
+             return_attention_mask=self.return_attention_mask
+         )
+         waveform_batch["input_values"] = torch.tensor(waveform_batch["input_values"])
+         attention_mask = waveform_batch.get("attention_mask")
+
+         if attention_mask is not None:
+             waveform_batch["attention_mask"] = attention_mask
+
+         # remove the per-clip DC offset (mean); per-clip std normalization is left commented out
+         waveform_batch["input_values"] = waveform_batch["input_values"] - waveform_batch["input_values"].mean(axis=1, keepdims=True)
+         # waveform_batch["input_values"] = (waveform_batch["input_values"] - waveform_batch["input_values"].mean(axis=1, keepdims=True)) / (waveform_batch["input_values"].std(axis=1, keepdims=True) + 1e-8)
+         return waveform_batch
+
+     def _compute_fbank_features(self, waveforms):
+         fbank_features = [
+             fbank(
+                 waveform.unsqueeze(0),
+                 htk_compat=self.htk_compat,
+                 sample_frequency=self.sampling_rate,
+                 use_energy=self.use_energy,
+                 window_type=self.window_type,
+                 num_mel_bins=self.num_mel_bins,
+                 dither=self.dither,
+                 frame_shift=self.frame_shift
+             )
+             for waveform in waveforms
+         ]
+         return torch.stack(fbank_features)
+
+     def _pad_and_normalize(self, fbank_features):
+         # fbank_features has shape (batch, time_frames, num_mel_bins)
+         difference = self.target_length - fbank_features[0].shape[0]
+         min_value = fbank_features.min()
+
+         if self.target_length > fbank_features.shape[1]:  # compare against the time dimension, not the batch size
+             padding = (0, 0, 0, difference)
+             fbank_features = F.pad(fbank_features, padding, value=min_value.item())
+
+         fbank_features = (fbank_features - self.mean) / (self.std * 2)
+         return fbank_features
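The extractor can also be exercised standalone on a synthetic clip to confirm the output geometry the model expects (a sketch; a random 5 s waveform at 32 kHz stands in for real audio):

import numpy as np
from feature_extractor import BirdMAEFeatureExtractor

extractor = BirdMAEFeatureExtractor()
waveform = np.random.randn(5 * 32_000).astype(np.float32)  # 5 s of noise at 32 kHz
features = extractor(waveform)  # pad/truncate the waveform, compute log-mel fbank features, normalize
print(features.shape)  # torch.Size([1, 512, 128]) -> (batch, time frames, mel bins)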
model.py ADDED
@@ -0,0 +1,452 @@
+ import math
+ import collections.abc
+ from functools import partial
+ from itertools import repeat
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from transformers import PreTrainedModel
+ from transformers.utils import logging
+ from transformers.modeling_outputs import BaseModelOutput
+
+ logger = logging.get_logger(__name__)
+
+ from .config import BirdMAEConfig
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: tuple of (grid height, grid width)
+     return:
+     pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size[0], dtype=np.float32)  # grid_size[0] = 8
+     grid_w = np.arange(grid_size[1], dtype=np.float32)  # grid_size[1] = 32
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)  # 2, 8, 32
+
+     grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])  # 2, 1, 8, 32
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed  # 257 (256 patches + cls) x embed_dim
+
+
+ # From timm.models.weight_init
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         logger.warning("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                        f"The distribution may be severely truncated. (Current mean: {mean}, std: {std}, [a, b]: [{a}, {b}])")
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+     return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     """Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are within :math:`[a, b]` interval.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         return _trunc_normal_(tensor, mean, std, a, b)
+
+
+ # From timm.models.layers
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         if self.drop_prob == 0. or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0 and self.scale_by_keep:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+
+
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+             return tuple(x)
+         return tuple(repeat(x, n))
+     return parse
+
+
+ class Mlp(nn.Module):
+     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+     """
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+         norm_layer=None,
+         bias=True,
+         drop=0.,
+         use_conv=False,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         bias = _ntuple(2)(bias)
+         drop_probs = _ntuple(2)(drop)
+         linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+         self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+         self.act = act_layer()
+         self.drop1 = nn.Dropout(drop_probs[0])
+         self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+         self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+         self.drop2 = nn.Dropout(drop_probs[1])
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop1(x)
+         x = self.norm(x)
+         x = self.fc2(x)
+         x = self.drop2(x)
+         return x
+
+
+ # From timm.models.vision_transformer
+ class Attention(nn.Module):
+     fused_attn: bool
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         attn_drop: float = 0.,
+         proj_drop: float = 0.,
+         norm_layer: nn.Module = nn.LayerNorm,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+         # self.fused_attn = use_fused_attn()
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         # fused attention path only; the non-fused timm fallback is kept below for reference
+         x = F.scaled_dot_product_attention(
+             q, k, v,
+             dropout_p=self.attn_drop.p if self.training else 0.,
+         )
+         # else:
+         #     q = q * self.scale
+         #     attn = q @ k.transpose(-2, -1)
+         #     attn = attn.softmax(dim=-1)
+         #     attn = self.attn_drop(attn)
+         #     x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ # From timm.models.vision_transformer
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         proj_drop: float = 0.,
+         attn_drop: float = 0.,
+         init_values: float = None,
+         drop_path: float = 0.,
+         act_layer: nn.Module = nn.GELU,
+         norm_layer: nn.Module = nn.LayerNorm,
+         mlp_layer: nn.Module = Mlp,
+     ) -> None:
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_norm=qk_norm,
+             attn_drop=attn_drop,
+             proj_drop=proj_drop,
+             norm_layer=norm_layer,
+         )
+         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = mlp_layer(
+             in_features=dim,
+             hidden_features=int(dim * mlp_ratio),
+             act_layer=act_layer,
+             drop=proj_drop,
+         )
+         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+         return x
+
+
+ # From timm.models.vision_transformer
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: float = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class PatchEmbed_org(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self,
+                  img_size: int | tuple[int, ...] = 224,
+                  patch_size: int | tuple[int, ...] = 16,
+                  in_chans=3,
+                  embed_dim=768):
+         super().__init__()
+         img_size: tuple[int, int] = _ntuple(2)(img_size)  # audio MAE used (target_length x 128)
+         patch_size: tuple[int, int] = _ntuple(2)(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])  # number of patches height/width = 8/32
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape   # batch size, channels, height, width
+         # x = x.permute(0, 1, 3, 2)
+         x = self.proj(x)       # 1, 1, 512, 128 -> 1, 768, 32, 8 (batch, 768 channels, 32 height, 8 width)
+         x = x.flatten(2)       # 1, 768, 32, 8 -> 1, 768, 256
+         x = x.transpose(1, 2)  # 1, 768, 256 -> 1, 256, 768
+         return x
+
+
+ # --- END OF NECESSARY TIMM/Custom internal module definitions ---
+
+ class BirdMAEModel(PreTrainedModel):
+     config_class = BirdMAEConfig
+     base_model_prefix = "BirdMAE"
+     main_input_name = "input_values"
+     _auto_class = "AutoModel"
+     _keys_to_ignore_on_load_missing = ["fc_norm.weight", "fc_norm.bias"]
+
+     def __init__(self, config: BirdMAEConfig, **kwargs):
+         super().__init__(config)
+         self.config = config
+
+         # The norm_layer partial is defined within the original MAE_Encoder
+         norm_layer = partial(nn.LayerNorm, eps=config.norm_layer_eps)
+
+         self.patch_embed = PatchEmbed_org(
+             img_size=(config.img_size_x, config.img_size_y),  # (512, 128)
+             patch_size=config.patch_size,
+             in_chans=config.in_chans,
+             embed_dim=config.embed_dim
+         )
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, config.num_patches + 1, config.embed_dim),
+             requires_grad=config.pos_trainable
+         )
+
+         # Positional embedding initialization
+         if self.pos_embed.data.shape[1] == config.num_patches + 1:
+             pos_embed_np = get_2d_sincos_pos_embed_flexible(
+                 self.pos_embed.shape[-1],   # embedding dim
+                 self.patch_embed.patch_hw,  # (8, 32) for a 512x128 spectrogram with 16x16 patches
+                 cls_token=True
+             )
+             self.pos_embed.data.copy_(torch.from_numpy(pos_embed_np).float().unsqueeze(0))
+         else:
+             logger.warning("Positional embedding shape mismatch. Will not initialize sin-cos pos embed.")
+
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=config.embed_dim,
+                 num_heads=config.num_heads,
+                 mlp_ratio=config.mlp_ratio,
+                 qkv_bias=config.qkv_bias,
+                 qk_norm=config.qk_norm,
+                 init_values=config.init_values,
+                 proj_drop=config.proj_drop_rate,
+                 attn_drop=config.attn_drop_rate,
+                 drop_path=dpr[i],
+                 norm_layer=norm_layer
+             )
+             for i in range(config.depth)
+         ])
+
+         self.pos_drop = nn.Dropout(p=config.pos_drop_rate)
+         self.norm = norm_layer(config.embed_dim)
+         self.fc_norm = norm_layer(config.embed_dim)
+         self.global_pool = kwargs.get("global_pool", "average")
+
+         trunc_normal_(self.cls_token, std=.02)  # timm uses trunc_normal_
+
+     # used when the model is initialized from scratch
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):  # from the original init_weights
+             w = m.weight.data
+             torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+     def forward(
+         self,
+         input_features: torch.Tensor,  # spectrograms of shape (B, C, H, W) -> (B, 1, 512, 128)
+         # attention_mask=None,  # for padding (B, num_time_patches)
+         output_attentions: bool = False,
+         output_hidden_states: bool = None,
+         return_dict: bool = None,
+     ):
+         if len(input_features.shape) == 3:
+             # add the channel dimension: (B, T, M) -> (B, 1, T, M)
+             input_features = input_features.unsqueeze(1)
+
+         output_attentions = output_attentions or self.config.output_attentions
+
+         if output_attentions:
+             raise NotImplementedError("output_attentions is not yet supported")
+         output_hidden_states = output_hidden_states or self.config.output_hidden_states
+         return_dict = return_dict or self.config.use_return_dict
+
+         B, C, X, Y = input_features.shape
+         assert X == self.config.img_size_x, f"Expected image_size_x={self.config.img_size_x} but was {X}."
+         assert Y == self.config.img_size_y, f"Expected image_size_y={self.config.img_size_y} but was {Y}."
+
+         x = self.patch_embed(input_features)  # output: (B, num_patches, embed_dim) -> (B, 256, 768)
+
+         x = x + self.pos_embed[:, 1:, :]
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = self.pos_drop(x)
+
+         all_hidden_states = (x,) if output_hidden_states else None
+         for blk in self.blocks:
+             x = blk(x)
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (x,)
+
+         if self.global_pool == "average":
+             x = x[:, 1:, :].mean(dim=1)
+             pooled_output = self.fc_norm(x)
+         elif self.global_pool == "cls":
+             x = self.norm(x)
+             pooled_output = x[:, 0]
+         else:
+             raise ValueError(f"Invalid global pool type: {self.global_pool}")
+
+         if not return_dict:
+             return (pooled_output,) + (all_hidden_states if output_hidden_states else ()) + (None,)
+
+         return BaseModelOutput(
+             last_hidden_state=pooled_output,
+             hidden_states=all_hidden_states,
+             attentions=None
+         )
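model.py imports its config relatively (from .config import ...), so it is meant to be loaded through the Auto machinery rather than imported as a plain script. An end-to-end sketch with a placeholder repo id, combining the feature extractor above with the model:

import numpy as np
import torch
from transformers import AutoModel
from feature_extractor import BirdMAEFeatureExtractor

repo_id = "<user>/<bird-mae-repo>"  # placeholder repo id (a local clone of this repo also works)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
extractor = BirdMAEFeatureExtractor()

waveform = np.random.randn(5 * 32_000).astype(np.float32)  # 5 s clip at 32 kHz
features = extractor(waveform)                              # (1, 512, 128) log-mel fbank features
with torch.no_grad():
    out = model(features)                                   # channel dim is added inside forward
print(out.last_hidden_state.shape)                          # torch.Size([1, 768]) mean-pooled clip embedding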
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:402103ecc81787ca6317ad78377fb38a912b3548f23a18f424d7a0dc31f754c5
+ size 341826344
preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "auto_map": {
+     "AutoFeatureExtractor": "feature_extractor.BirdMAEFeatureExtractor"
+   },
+   "dither": 0.0,
+   "feature_extractor_type": "BirdMAEFeatureExtractor",
+   "feature_size": 1,
+   "frame_shift": 10,
+   "htk_compat": true,
+   "mean": -7.2,
+   "num_mel_bins": 128,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": true,
+   "sampling_rate": 32000,
+   "std": 4.43,
+   "target_length": 512,
+   "use_energy": false,
+   "window_type": "hanning"
+ }
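preprocessor_config.json mirrors the extractor's __init__ defaults and registers the class in auto_map, so the same trust_remote_code loading path applies on the preprocessing side (placeholder repo id again):

from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("<user>/<bird-mae-repo>", trust_remote_code=True)
print(extractor.sampling_rate, extractor.num_mel_bins, extractor.target_length)  # 32000 128 512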