Cheeky Sparrow commited on
Commit
426874e
·
1 Parent(s): 2c6a5a0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/484366__spacejoe__bird-3.wav filter=lfs diff=lfs merge=lfs -text
37
+ assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 filter=lfs diff=lfs merge=lfs -text
38
+ assets/nri-battlesounds.mp3 filter=lfs diff=lfs merge=lfs -text
39
+ assets/nri-GreenTreeFrogEvergladesNP.mp3 filter=lfs diff=lfs merge=lfs -text
40
+ assets/nri-StreamMUWO.mp3 filter=lfs diff=lfs merge=lfs -text
41
+ assets/esp_favicon.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/naturelm-audio-overiew.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/nri-SensationJazz.mp3 filter=lfs diff=lfs merge=lfs -text
44
+ assets/yell-YELLAMRO20160506SM3.mp3 filter=lfs diff=lfs merge=lfs -text
45
+ assets/yell-YELLFLBCSACR20075171.mp3 filter=lfs diff=lfs merge=lfs -text
46
+ assets/yell-YELLWolfvCar20160111T22ms2.mp3 filter=lfs diff=lfs merge=lfs -text
NatureLM/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .config import Config
16
+ from .models.NatureLM import NatureLM
17
+ from .utils import generate_sample_batches, prepare_sample_waveforms
18
+
19
+ __all__ = ["Config", "NatureLM", "generate_sample_batches", "prepare_sample_waveforms"]
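For orientation, a minimal usage sketch (not part of the commit) of the exports above. The config path follows the configs/inference.yml default referenced in NatureLM/infer.py later in this diff, and the checkpoint name is the one used there; both are assumptions outside this file:

import torch

from NatureLM import Config, NatureLM

# Load an inference config that satisfies the Config schema defined in NatureLM/config.py.
cfg = Config.from_sources("configs/inference.yml")

# Same call used by load_model_and_config() in NatureLM/infer.py below.
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()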
NatureLM/augmentations.py ADDED
@@ -0,0 +1,349 @@
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from NatureLM.utils import mel_frequencies
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class RevEcho(nn.Module):
15
+ """
16
+ Hacky reverb that runs on the GPU without slowing down training. This reverb adds a
17
+ succession of attenuated echoes of the input signal to itself. Intuitively, the delay
18
+ of the first echo will happen after roughly 2x the radius of the room and is
19
+ controlled by `first_delay`. Then RevEcho keeps adding echoes with the same delay and
20
+ further attenuation until the amplitude ratio between the last and first echo is
21
+ 1e-3. The attenuation factor and the number of echoes to add are controlled by RT60
22
+ (measured in seconds). RT60 is the average time to get to -60dB (n.b. volume is
23
+ measured over the squared amplitude so this matches the 1e-3 ratio).
24
+
25
+ At each call to RevEcho, `first_delay`, `initial` and `RT60` are sampled from their
26
+ range. Then, to prevent this reverb from being too regular, the delay time is
27
+ resampled uniformly within `first_delay +/- 10%`, as controlled by the `jitter`
28
+ parameter.
29
+
30
+ Finally, for a denser reverb, multiple trains of echoes are added with different
31
+ jitter noises.
32
+
33
+ Args:
34
+ - initial: amplitude of the first echo as a fraction of the input signal. For
35
+ each sample, actually sampled from `[0, initial]`. Larger values mean louder
36
+ reverb. Physically, this would depend on the absorption of the room walls.
37
+ - rt60: range of values to sample the RT60 in seconds, i.e. after RT60 seconds,
38
+ the echo amplitude is 1e-3 of the first echo. The default values follow the
39
+ recommendations of https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf,
40
+ Section 2.4. Physically this would also be related to the absorption of the
41
+ room walls and there is likely a relation between `RT60` and `initial`, which
42
+ we ignore here.
43
+ - first_delay: range of values to sample the first echo delay in seconds. The
44
+ default values are equivalent to sampling a room of 3 to 10 meters.
45
+ - repeat: how many trains of echoes with different jitters to add. Higher values
46
+ mean a denser reverb.
47
+ - jitter: jitter used to make each repetition of the reverb echo train slightly
48
+ different. For instance, a jitter of 0.1 means the delay between two echoes will
49
+ be in the range `first_delay +- 10%`, with the jittering noise being resampled
50
+ after each single echo.
51
+ - keep_clean: fraction of the reverb of the clean speech to add back to the
52
+ ground truth. 0 = dereverberation, 1 = no dereverberation.
53
+ - sample_rate: sample rate of the input signals.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ proba=0.5,
59
+ initial=0.3,
60
+ rt60=(0.3, 1.3),
61
+ first_delay=(0.01, 0.03),
62
+ repeat=3,
63
+ jitter=0.1,
64
+ keep_clean=0.1,
65
+ sample_rate=16000,
66
+ rng=None,
67
+ seed=42,
68
+ ):
69
+ super().__init__()
70
+
71
+ self.proba = proba
72
+ self.initial = initial
73
+ self.rt60 = rt60
74
+ self.first_delay = first_delay
75
+ self.repeat = repeat
76
+ self.jitter = jitter
77
+ self.keep_clean = keep_clean
78
+ self.sample_rate = sample_rate
79
+ self.seed = seed
80
+ self.rng = rng if rng is not None else random.Random(self.seed)
81
+
82
+ def _reverb(self, source, initial, first_delay, rt60):
83
+ """
84
+ Return the reverb for a single source.
85
+ """
86
+ length = source.shape[-1]
87
+ reverb = th.zeros_like(source)
88
+
89
+ for _ in range(self.repeat):
90
+ frac = 1 # what fraction of the first echo amplitude is still here
91
+ echo = initial * source
92
+ while frac > 1e-3:
93
+ # First jitter noise for the delay
94
+ jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
95
+ delay = min(1 + int(jitter * first_delay * self.sample_rate), length)
96
+
97
+ # Delay the echo in time by padding with zero on the left
98
+ echo = F.pad(echo[:, :, :-delay], (delay, 0))
99
+ reverb += echo
100
+
101
+ # Second jitter noise for the attenuation
102
+ jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
103
+ # we want, with `d` the per-echo attenuation, d ** (rt60 / first_delay) = 1e-3,
+ # i.e. log10(d) = -3 * first_delay / rt60, hence:
105
+ attenuation = 10 ** (-3 * jitter * first_delay / rt60)
106
+ echo *= attenuation
107
+ frac *= attenuation
108
+
109
+ return reverb
110
+
111
+ def forward(self, samples):
112
+ if self.rng.random() >= self.proba:
113
+ return samples
114
+
115
+ raw_wav = samples.get("raw_wav", None)
116
+
117
+ # add channel dimension if not exist
118
+ if raw_wav.dim() == 2:
119
+ raw_wav = raw_wav.unsqueeze(1)
120
+
121
+ # Sample characteristics for the reverb
122
+ initial = self.rng.random() * self.initial
123
+ first_delay = self.rng.uniform(*self.first_delay)
124
+ rt60 = self.rng.uniform(*self.rt60)
125
+
126
+ reverb_wav = self._reverb(raw_wav, initial, first_delay, rt60)
127
+ raw_wav += self.keep_clean * reverb_wav
128
+
129
+ # remove channel dimension
130
+ if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
131
+ raw_wav = raw_wav.squeeze(1)
132
+
133
+ samples["raw_wav"] = raw_wav
134
+ return samples
135
+
136
+
137
+ class BandMask(nn.Module):
138
+ """
139
+ Masks bands of frequencies. Similar to Park, Daniel S., et al.
140
+ "Specaugment: A simple data augmentation method for automatic speech recognition."
141
+ (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
142
+ """
143
+
144
+ def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000, rng=None, seed=42):
145
+ """__init__.
146
+
147
+ :param maxwidth: the maximum width to remove
148
+ :param bands: number of bands
149
+ :param sample_rate: signal sample rate
150
+ """
151
+ super().__init__()
152
+ self.maxwidth = maxwidth
153
+ self.bands = bands
154
+ self.sample_rate = sample_rate
155
+ self.seed = seed
156
+ self.rng = rng if rng is not None else random.Random(self.seed)
157
+
158
+ def forward(self, samples):
159
+ raw_wav = samples.get("raw_wav", None)
160
+
161
+ # add channel dimension if not exist
162
+ if raw_wav.dim() == 2:
163
+ raw_wav = raw_wav.unsqueeze(1)
164
+
165
+ bands = self.bands
166
+ bandwidth = int(abs(self.maxwidth) * bands)
167
+ mels = mel_frequencies(bands, 40, self.sample_rate / 2) / self.sample_rate
168
+ low = self.rng.randrange(bands)
169
+ high = self.rng.randrange(low, min(bands, low + bandwidth))
170
+
171
+ filters = LowPassFilters([mels[low], mels[high]]).to(raw_wav.device)
172
+
173
+ low, midlow = filters(raw_wav)
174
+ # band pass filtering
175
+ out = raw_wav - midlow + low
176
+
177
+ # remove channel dimension
178
+ if out.dim() == 3 and out.shape[1] == 1:
179
+ out = out.squeeze(1)
180
+
181
+ samples["raw_wav"] = out
182
+ return samples
183
+
184
+
185
+ class Shift(nn.Module):
186
+ def __init__(self, shift=8192, same=False, rngth=None):
187
+ """
188
+ :param shift: randomly shift the signals by up to this many samples
+ :param same: use the same random shift for every signal in the batch
190
+ """
191
+ super().__init__()
192
+ self.shift = shift
193
+ self.same = same
194
+ self.rngth = rngth
195
+
196
+ def forward(self, samples):
197
+ raw_wav = samples.get("raw_wav", None)
198
+ batch, channels, length = raw_wav.shape
199
+ length = length - self.shift
200
+ if self.shift > 0:
201
+ offsets = th.randint(
202
+ self.shift, [1 if self.same else batch, 1, 1], device=raw_wav.device, generator=self.rngth
203
+ )
204
+ offsets = offsets.expand(-1, channels, -1)
205
+ indexes = th.arange(length, device=raw_wav.device)
206
209
+ raw_wav = raw_wav.gather(2, indexes + offsets)
210
+ samples["raw_wav"] = raw_wav
211
+ return samples
212
+
213
+
214
+ class TimeScale(nn.Module):
215
+ """Fast time scale."""
216
+
217
+ def __init__(self, scale=2.0, target=1, rngnp=None, seed=42):
218
+ """
219
+ :param scale: randomly scales up to this maximum factor
220
+ """
221
+ super().__init__()
222
+ self.scale = scale
223
+ self.target = target
224
+ self.seed = seed
225
+ self.rngnp = rngnp if rngnp is not None else np.random.default_rng(seed=self.seed)
226
+
227
+ def forward(self, samples):
228
+ try:
229
+ raw_wav = samples["raw_wav"]
230
+ except KeyError:
231
+ logger.error("Missing required key 'raw_wav' in samples dict")
232
+ raise
233
+
234
+ if "padding_mask" in samples:
235
+ masks = samples.get("padding_mask")
236
+ else:
237
+ masks = th.ones_like(raw_wav)
238
+
239
+ # add channel dimension if not exist
240
+ if raw_wav.dim() == 2:
241
+ raw_wav = raw_wav.unsqueeze(1)
242
+ masks = masks.unsqueeze(1)
243
+
244
+ # what to augment: noise, clean, or both
245
+ if self.target == -1:
246
+ targets = [i for i in range(raw_wav.shape[0])]
247
+ else:
248
+ targets = [self.target]
249
+
250
+ for t in targets:
251
+ signal = raw_wav[t]
252
+ scaling = np.power(self.scale, self.rngnp.uniform(-1, 1))
253
+ output_size = int(signal.shape[-1] * scaling)
254
+ ref = th.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
255
+
256
+ ref1 = ref.clone().type(th.int64)
257
+ ref2 = th.min(ref1 + 1, th.full_like(ref1, signal.shape[-1] - 1, dtype=th.int64))
258
+ r = ref - ref1.type(ref.type())
259
+ scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
260
+ scaled_masks = masks[t][..., ref1] * (1 - r) + masks[t][..., ref2] * r
261
+
262
+ # trim or zero pad to the original size
263
+ if scaled_signal.shape[-1] > signal.shape[-1]:
264
+ nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
265
+ scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
266
+ scaled_masks = scaled_masks[..., nframes_offset : nframes_offset + signal.shape[-1]]
267
+ else:
268
+ nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
269
+ pad_left = int(self.rngnp.uniform() * nframes_diff)
270
+ pad_right = nframes_diff - pad_left
271
+ scaled_signal = F.pad(
272
+ input=scaled_signal, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
273
+ )
274
+ scaled_masks = F.pad(
275
+ input=scaled_masks, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
276
+ )
277
+ raw_wav[t] = scaled_signal
278
+ masks[t] = scaled_masks
279
+
280
+ # remove channel dimension
281
+ if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
282
+ raw_wav = raw_wav.squeeze(1)
283
+ masks = masks.squeeze(1)
284
+
285
+ samples["raw_wav"] = raw_wav
286
+ samples["padding_mask"] = masks
287
+
288
+ return samples
289
+
290
+
291
+ class Flip(nn.Module):
292
+ def __init__(self, p=0.0, rngth=None):
293
+ super(Flip, self).__init__()
294
+
295
+ self.p = p
296
+ self.rngth = rngth
297
+
298
+ def forward(self, samples):
299
+ raw_wav = samples["raw_wav"]
300
+ if raw_wav.dim() > 2:
301
+ flip_mask = th.rand(raw_wav.shape[0], device=raw_wav.device, generator=self.rngth) <= self.p
302
+ raw_wav[flip_mask] = raw_wav[flip_mask].flip(-1)
303
+ else:
304
+ if th.rand(1, generator=self.rngth) <= self.p:
305
+ raw_wav = raw_wav.flip(0)
306
+ samples["raw_wav"] = raw_wav
307
+ return samples
308
+
309
+
310
+ class LowPassFilters(th.nn.Module):
311
+ """
312
+ Bank of low pass filters.
313
+
314
+ Args:
315
+ cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
316
+ f_s is the samplerate.
317
+ width (int | None): width of the filters (i.e. kernel_size=2 * width + 1).
318
+ Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
319
+ but more side effects.
320
+ Shape:
321
+ - Input: `(*, T)`
322
+ - Output: `(F, *, T)` with `F` the length of `cutoffs`.
323
+ """
324
+
325
+ def __init__(self, cutoffs: list, width: int | None = None):
326
+ super().__init__()
327
+
328
+ self.cutoffs = cutoffs
329
+
330
+ if not width:
331
+ width = int(2 / min(cutoffs))
332
+ self.width = width
333
+
334
+ window = th.hamming_window(2 * width + 1, periodic=False)
335
+ t = np.arange(-width, width + 1, dtype=np.float32)
336
+ filters = []
337
+ for cutoff in cutoffs:
338
+ sinc = th.from_numpy(np.sinc(2 * cutoff * t))
339
+ filters.append(2 * cutoff * sinc * window)
340
+ self.register_buffer("filters", th.stack(filters).unsqueeze(1))
341
+
342
+ def forward(self, input):
343
+ *others, t = input.shape
344
+ input = input.view(-1, 1, t)
345
+ out = F.conv1d(input, self.filters, padding=self.width)
346
+ return out.permute(1, 0, 2).reshape(-1, *others, t)
347
+
348
+ def __repr__(self):
349
+ return "LowPassFilters(width={}, cutoffs={})".format(self.width, self.cutoffs)
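All of the augmentations above consume and return the same samples dict keyed by "raw_wav" (optionally with "padding_mask"), so they compose by simple chaining. A small sketch under that assumption, using toy 16 kHz mono waveforms shaped (batch, time):

import torch

from NatureLM.augmentations import BandMask, RevEcho, TimeScale

# Toy batch of four 1-second waveforms at 16 kHz.
samples = {"raw_wav": torch.randn(4, 16000)}

# Each module takes the samples dict and returns it with "raw_wav" augmented.
for aug in (RevEcho(proba=0.5), BandMask(maxwidth=0.2), TimeScale(scale=1.2, target=-1)):
    samples = aug(samples)

print(samples["raw_wav"].shape)  # still torch.Size([4, 16000])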
NatureLM/checkpoint_utils.py ADDED
@@ -0,0 +1,100 @@
1
+ """Module for training utilities.
2
+
3
+ This module contains utility functions for training models. For example, saving model checkpoints.
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import tempfile
9
+ from typing import Any, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module:
18
+ return model.module if use_distributed else model
19
+
20
+
21
+ def get_state_dict(model, drop_untrained_params: bool = True) -> dict[str, Any]:
22
+ """Get model state dict. Optionally drop untrained parameters to keep only those that require gradient.
23
+
24
+ Args:
25
+ model: Model to get state dict from
26
+ drop_untrained_params: Whether to drop untrained parameters
27
+
28
+ Returns:
29
+ dict: Model state dict
30
+ """
31
+ if not drop_untrained_params:
32
+ return model.state_dict()
33
+
34
+ param_grad_dict = {k: v.requires_grad for (k, v) in model.named_parameters()}
35
+ state_dict = model.state_dict()
36
+
37
+ for k in list(state_dict.keys()):
38
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
39
+ # delete parameters that do not require gradient
40
+ del state_dict[k]
41
+
42
+ return state_dict
43
+
44
+
45
+ def torch_save_to_bucket(save_obj: Any, save_path: Union[str, os.PathLike], compress: bool = True) -> None:
46
+ """Save an object to a GCS bucket by serializing to a temporary local file and uploading it.
47
+
48
+ Args:
49
+ save_obj: Object to save (usually model state dict or checkpoint)
50
+ save_path: Path to save in GCS bucket (must be gs:// path)
51
+ compress: Whether to use compression. Default: True
52
+ """
53
+ if not is_gcs_path(save_path):
54
+ raise ValueError("save_path must be a GCS path")
55
+
56
+ # save to a temporary local file and then upload to GCS
57
+ with tempfile.NamedTemporaryFile() as tmp:
58
+ torch.save(save_obj, tmp.name, _use_new_zipfile_serialization=compress)
59
+ try:
60
+ save_path.upload_from(tmp.name)
61
+ except Exception as e:
62
+ logger.error(f"Error saving to GCP bucket: {e}")
63
+ raise e
64
+
65
+
66
+ def save_model_checkpoint(
67
+ model: nn.Module,
68
+ save_path: Union[str, os.PathLike],
69
+ use_distributed: bool = False,
70
+ drop_untrained_params: bool = False,
71
+ **objects_to_save,
72
+ ) -> None:
73
+ """Save model checkpoint.
74
+
75
+ Args:
76
+ model (nn.Module): Model to save
77
+ save_path (str | os.PathLike): Destination path for the checkpoint (local path or gs:// URI).
+ use_distributed (bool): Whether the model is wrapped in DDP; if so, unwrap it before saving. Default: False.
+ drop_untrained_params (bool): Whether to drop parameters that do not require gradients. Default: False.
83
+ **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
84
+ """
85
+ if not is_gcs_path(save_path) and not os.path.exists(os.path.dirname(save_path)):
86
+ raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
87
+
88
+ model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
89
+ state_dict = get_state_dict(model_no_ddp, drop_untrained_params)
90
+ save_obj = {
91
+ "model": state_dict,
92
+ **objects_to_save,
93
+ }
94
+
95
+ logger.info("Saving checkpoint to {}.".format(save_path))
96
+
97
+ if is_gcs_path(save_path):
98
+ torch_save_to_bucket(save_obj, save_path)
99
+ else:
100
+ torch.save(save_obj, save_path)
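A sketch of how save_model_checkpoint might be called from a training loop. The model, optimizer and path are illustrative, and the example assumes the is_gcs_path helper referenced (but not imported) in this module is available at runtime:

import torch

from NatureLM.checkpoint_utils import save_model_checkpoint

model = torch.nn.Linear(8, 2)                      # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters())  # extra state goes through **objects_to_save

save_model_checkpoint(
    model,
    save_path="checkpoints/checkpoint_0.pth",  # parent directory must already exist
    use_distributed=False,
    drop_untrained_params=False,
    optimizer=optimizer.state_dict(),
    epoch=0,
)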
NatureLM/config.py ADDED
@@ -0,0 +1,234 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from pathlib import Path
17
+ from typing import Any, Literal
18
+
19
+ import yaml
20
+ from pydantic import BaseModel, field_validator
21
+ from pydantic.v1.utils import deep_update
22
+ from pydantic_settings import BaseSettings, CliSettingsSource, YamlConfigSettingsSource
23
+
24
+
25
+ class OptimizerConfig(BaseModel, extra="forbid", validate_assignment=True):
26
+ max_epoch: int
27
+ warmup_steps: int
28
+ warmup_start_lr: float = -1
29
+ init_lr: float
30
+ min_lr: float
31
+ weight_decay: float
32
+ beta2: float = 0.999
33
+ max_grad_norm: float | None = None
34
+ max_grad_value: float | None = None
35
+ device: str = "cuda"
36
+
37
+
38
+ class AugmentationsConfig(BaseModel, extra="forbid", validate_assignment=True):
39
+ use_augmentation: bool = False
40
+
41
+ noise_prob: float = 0
42
+ noise_dirs: list[Path] | None = None
43
+ low_snr: float = -5
44
+ high_snr: float = 20
45
+ time_scale_prob: float = 0
46
+ time_scale: float = 1.2
47
+ mixup_prob: float = 0
48
+ mixup_count: int = 3
49
+ mask_audio_prob: float = 0
50
+
51
+
52
+ class RunConfig(BaseModel, extra="forbid", validate_assignment=True):
53
+ wandb_enabled: bool = True
54
+ amp: bool = False
55
+ seed: int
56
+ output_dir: Path
57
+ evaluate: bool
58
+ log_freq: int
59
+ epoch_based: bool
60
+ iters_per_epoch: int
61
+ accum_grad_iters: int
62
+ batch_size_train: int
63
+ batch_size_eval: int
64
+ num_workers: int
65
+ custom_metrics: bool
66
+ decode_ratio: float
67
+
68
+ device: Literal["cuda", "cpu"] = "cuda"
69
+ use_distributed: bool = False
70
+
71
+ world_size: int = 1
72
+ rank: int = 0
73
+ gpu: int | None = None
74
+ dist_backend: Literal["nccl"] = "nccl"
75
+ dist_url: str = "env://"
76
+
77
+ optims: OptimizerConfig
78
+ augmentations: AugmentationsConfig
79
+
80
+
81
+ class DatasetsConfig(BaseModel, extra="forbid", validate_assignment=True):
82
+ train_ann_path: Path
83
+ valid_ann_path: Path
84
+ test_ann_path: Path
85
+ audio_max_length_seconds: int
86
+
87
+ @field_validator("train_ann_path", "valid_ann_path", "test_ann_path", mode="after")
88
+ @classmethod
89
+ def check_files(cls, path: Path) -> Path:
90
+ if not path.exists():
91
+ raise ValueError(f"File {path} does not exist")
92
+ if path.suffix.lower() != ".jsonl":
93
+ raise ValueError(f"File {path} must be a JSONL file")
94
+ return path
95
+
96
+
97
+ class BeatsConfig(BaseModel, extra="forbid", validate_assignment=True):
98
+ input_patch_size: int = -1
99
+ embed_dim: int = 512
100
+ conv_bias: bool = False
101
+
102
+ encoder_layers: int = 12
103
+ encoder_embed_dim: int = 768
104
+ encoder_ffn_embed_dim: int = 3072
105
+ encoder_attention_heads: int = 12
106
+ activation_fn: str = "gelu"
107
+
108
+ layer_wise_gradient_decay_ratio: float = 0.6
109
+ layer_norm_first: bool = False
110
+ deep_norm: bool = True
111
+
112
+ dropout: float = 0.0
113
+ attention_dropout: float = 0.0
114
+ activation_dropout: float = 0.0
115
+ encoder_layerdrop: float = 0.05
116
+ dropout_input: float = 0.0
117
+
118
+ conv_pos: int = 128
119
+ conv_pos_groups: int = 16
120
+
121
+ relative_position_embedding: bool = True
122
+ num_buckets: int = 320
123
+ max_distance: int = 800
124
+ gru_rel_pos: bool = True
125
+
126
+ finetuned_model: bool = True
127
+ predictor_dropout: float = 0.0
128
+ predictor_class: int = 527
129
+
130
+
131
+ class GenerateConfig(BaseModel, extra="forbid", validate_assignment=True):
132
+ max_new_tokens: int
133
+ num_beams: int
134
+ do_sample: bool
135
+ min_length: int
136
+ temperature: float
137
+ repetition_penalty: float
138
+ length_penalty: float
139
+
140
+
141
+ class ModelConfig(BaseModel, extra="forbid", validate_assignment=True):
142
+ llama_path: Path
143
+ beats_path: Path | None = None
144
+ beats_cfg: BeatsConfig
145
+ ckpt: Path | None = None
146
+ freeze_beats: bool = True
147
+ use_audio_Qformer: bool = True
148
+ max_pooling: bool = False
149
+ downsample_factor: int = 4
150
+ freeze_audio_QFormer: bool = False
151
+ window_level_Qformer: bool = True
152
+ num_audio_query_token: int = 1
153
+ second_per_window: float = 0.333333
154
+ second_stride: float = 0.333333
155
+ audio_llama_proj_model: Path | None = None
156
+ freeze_audio_llama_proj: bool = False
157
+ device: str = "cuda"
158
+ lora: bool = True
159
+ lora_rank: int = 8
160
+ lora_alpha: int = 32
161
+ lora_dropout: float = 0.1
162
+ flash_attn: Literal["eager", "flash_attention_2"] = "eager"
163
+ prompt_template: str = ""
164
+ max_txt_len: int = 128
165
+ end_sym: str = "</s>"
166
+
167
+ @field_validator("beats_path", "audio_llama_proj_model", "ckpt", mode="before")
168
+ @classmethod
169
+ def detect_gcs_path(cls, value: Any) -> Any:
170
+ """Pydantic's automatic type conversion cannot handle gs:// paths, so this hook
+ is where they would be detected and converted to GSPath objects _before_
+ validation. In this release it simply passes values through unchanged."""
173
+ return value
174
+
175
+ @field_validator("ckpt", "audio_llama_proj_model", mode="before")
176
+ @classmethod
177
+ def legacy_empty_str(cls, value: Any) -> Any:
178
+ """In some of our config files we use "" to indicate that we don't have
179
+ a checkpoint. We've now switched to using None for this in the Config model but
180
+ let's keep this validator for backwards compatibility so people don't have to
181
+ change their configs"""
182
+ if isinstance(value, str) and value == "":
183
+ return None
184
+ else:
185
+ return value
186
+
187
+ @classmethod
188
+ def from_yaml(cls, yaml_file: str | os.PathLike) -> "ModelConfig":
189
+ yaml_values = YamlConfigSettingsSource(cls, yaml_file=str(yaml_file))
190
+ return cls.model_validate(yaml_values())
191
+
192
+
193
+ class Config(BaseSettings, extra="forbid", validate_assignment=True):
194
+ model: ModelConfig
195
+ run: RunConfig | None = None
196
+ datasets: DatasetsConfig | None = None
197
+ generate: GenerateConfig | None = None
198
+
199
+ def pretty_print(self):
200
+ print(self.model_dump_json(indent=4))
201
+
202
+ @classmethod
203
+ def from_sources(cls, yaml_file: str | Path, cli_args: list[str] = []) -> "Config":
204
+ """Create a Config object from a YAML file and CLI arguments. If there are
205
+ any conflicts, the CLI arguments will take precedence over the YAML file."""
206
+
207
+ yaml_file = Path(yaml_file)
208
+ if not yaml_file.exists():
209
+ raise FileNotFoundError(f"Config file {yaml_file} does not exist")
210
+
211
+ yaml_values = YamlConfigSettingsSource(cls, yaml_file=yaml_file)
212
+ cli_values = CliSettingsSource(cls, cli_parse_args=["--" + opt for opt in cli_args])
213
+ final_values = deep_update(yaml_values(), cli_values())
214
+ return cls.model_validate(final_values)
215
+
216
+ def to_yaml(self, path: str | os.PathLike) -> None:
217
+ save_config_as_yaml(self, path)
218
+
219
+
220
+ def save_config_as_yaml(data: BaseModel, filepath: str | os.PathLike) -> None:
221
+ """
222
+ Pydantic supports serializing/exporting models to various formats (dict, json, etc)
223
+ but not to yaml. This function is a workaround for that limitation.
224
+ """
225
+
226
+ filepath = Path(filepath)
227
+
228
+ if filepath.exists():
229
+ raise FileExistsError(f"File {filepath} already exists")
230
+
231
+ # The mode="json" is required because otherwise yaml.safe_dump() can't deal with
232
+ # Path|GSPath objects
233
+ with filepath.open("w") as f:
234
+ yaml.safe_dump(data.model_dump(mode="json"), f, sort_keys=False)
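Config.from_sources merges the YAML file with command-line overrides, with the CLI taking precedence. A sketch of the intended call pattern; the YAML path is illustrative and the dotted override key assumes pydantic-settings' nested CLI syntax:

from NatureLM.config import Config

# from_sources() prepends "--" to each entry, so overrides are passed as bare "key=value" strings.
cfg = Config.from_sources(
    "configs/inference.yml",          # must satisfy the Config schema above
    cli_args=["model.lora_rank=16"],  # overrides model.lora_rank from the YAML
)
cfg.pretty_print()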
NatureLM/dataset.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ """
17
+ Mixing examples.
18
+ Can mix:
19
+ - base: options-detection, add: open-ended:
20
+ Take all open-ended labels. Add them to the options. Add them to the labels.
21
+ - base: open-ended, add: open-ended
22
+ Concatenate labels
23
+ """
24
+
25
+ import glob
26
+ import json
27
+ import os
28
+ import random
29
+ from collections import defaultdict
30
+ from pathlib import Path
31
+ from typing import Literal
32
+
33
+ import numpy as np
34
+ import soundfile as sf
35
+ import torch
36
+ from torch.nn.utils.rnn import pad_sequence
37
+ from torch.utils.data import Dataset
38
+
39
+ from NatureLM.utils import snr_scale, time_scale
40
+
41
+
42
+ def write_example_to_file(base_filename, audio, sr=16000, suffix="_output", save_dir="debug_outputs"):
43
+ """
44
+ Writes the audio tensor to a file for debugging or inspection purposes.
45
+
46
+ Args:
47
+ base_filename (str): The base name of the original file.
48
+ audio (torch.Tensor or numpy.ndarray): The audio waveform to save.
49
+ sr (int): Sampling rate of the audio (default: 16000 Hz).
50
+ suffix (str): Optional suffix to append to the filename.
51
+ save_dir (str): Directory where the files will be saved.
52
+ """
53
+ if isinstance(audio, torch.Tensor):
54
+ audio = audio.numpy() # Convert to numpy if necessary
55
+
56
+ # Ensure the save directory exists
57
+ os.makedirs(save_dir, exist_ok=True)
58
+
59
+ # Create the output file path
60
+ filename = f"{os.path.splitext(base_filename)[0]}{suffix}.wav"
61
+ output_path = os.path.join(save_dir, filename)
62
+
63
+ try:
64
+ # Write the audio to the file
65
+ sf.write(output_path, audio, sr)
66
+ print(f"Saved audio to {output_path}")
67
+ except Exception as e:
68
+ print(f"Failed to write audio to file: {e}")
69
+
70
+
71
+ # Example usage in your code
72
+ # write_example_to_file(os.path.basename(ann["path"]), audio, suffix="_ts")
73
+
74
+
75
+ def collater(samples):
76
+ """Collate samples into a batch.
77
+
78
+ Samples is a list of dictionaries, each containing the following keys:
79
+ - raw_wav: a list of tensors containing the raw audio waveform
80
+ - text: a list of strings containing the text
81
+ - task: a list of strings containing the task
82
+ - id: a list of strings containing the id
83
+ - prompt: a list of strings containing the prompt
84
+ - index: a list of integers containing the index
85
+
86
+ The individual audio waveforms will be stacked along the batch dimension for easier
+ processing in the audio model. To keep track of which audio belongs to which sample, we add
88
+ the audio_chunk_sizes key to the batch dictionary.
89
+ """
90
+ flat_raw_wav = []
91
+ audio_chunk_sizes = []
92
+
93
+ for s in samples:
94
+ chunk_size = len(s["raw_wav"])
95
+ audio_chunk_sizes.append(chunk_size)
96
+ flat_raw_wav.extend(s["raw_wav"])
97
+ # raw_wav = [torch.from_numpy(a) for a in flat_raw_wav]
98
+ raw_wav = flat_raw_wav
99
+ raw_wav_length = torch.tensor([len(a) for a in raw_wav])
100
+ raw_wav = pad_sequence(raw_wav, batch_first=True, padding_value=0)
101
+ padding_mask = torch.arange(raw_wav.size(1)).unsqueeze(0) >= raw_wav_length.unsqueeze(1)
102
+
103
+ text = [s["text"] for s in samples]
104
+ prompt = [s["prompt"] for s in samples]
105
+ task = [s["task"] for s in samples]
106
+ id = [s["id"] for s in samples]
107
+ index = [s["index"] for s in samples]
108
+
109
+ return {
110
+ "raw_wav": raw_wav,
111
+ "padding_mask": padding_mask,
112
+ "text": text,
113
+ "task": task,
114
+ "id": id,
115
+ "prompt": prompt,
116
+ "index": index,
117
+ "audio_chunk_sizes": audio_chunk_sizes,
118
+ }
119
+
120
+
121
+ class NatureLMDataset(Dataset):
122
+ def __init__(
123
+ self,
124
+ ann_path: str | Path,
125
+ *,
126
+ max_length_seconds: int = 10,
127
+ cropping: Literal["random", "start"] | None = "random",
128
+ noise_prob: float = 0.0,
129
+ noise_dirs: list[str] | list[Path] | None = None,
130
+ low_snr: float = -5,
131
+ high_snr: float = 20,
132
+ time_scale_prob: float = 0.0,
133
+ time_scale: float = 1.2,
134
+ seed: int = 0,
135
+ mixup_prob: float = 0.0,
136
+ mixup_count: int = 3,
137
+ use_augmentation: bool = False,
138
+ mask_audio_prob: float = 0.0,
139
+ ):
140
+ super().__init__()
141
+
142
+ ann_path = Path(ann_path)
143
+
144
+ if not ann_path.exists():
145
+ raise FileNotFoundError(f"Dataset file {ann_path} not found")
146
+
147
+ try:
148
+ with open(ann_path, "r") as f:
149
+ data = json.load(f)
150
+ self.annotation = data["annotation"]
151
+ except (json.JSONDecodeError, KeyError):
152
+ with open(ann_path, "r") as f:
153
+ self.annotation = [json.loads(line) for line in f]
154
+
155
+ #### mixup related variables
156
+ ### hash table for tasks to sample the tasks faster
157
+ self.tasks = defaultdict(list)
158
+ for i, ann in enumerate(self.annotation):
159
+ if "task" in ann and "text" in ann and ann["text"] != "None" and "path" in ann:
160
+ self.tasks[ann["task"]].append(i)
161
+
162
+ self.mixup_tasks = {
163
+ task: []
164
+ for task in self.tasks.keys()
165
+ if task.endswith("simple-detection")
166
+ or task.endswith("multiple-detection") # Add more tasks after validating prompt mixing.
167
+ or task.endswith("sci-detection-random")
168
+ or task.endswith("common-detection-random")
169
+ }
170
+ for k in self.mixup_tasks.keys():
171
+ # whichever the base, only mix open-ended tasks.
172
+ if "sci-" in k:
173
+ self.mixup_tasks[k] = [
174
+ task
175
+ for task in self.mixup_tasks.keys()
176
+ if task.endswith("sci-simple-detection") or task.endswith("sci-multiple-detection")
177
+ ]
178
+ elif "common-" in k:
179
+ self.mixup_tasks[k] = [
180
+ task
181
+ for task in self.mixup_tasks.keys()
182
+ if task.endswith("common-simple-detection") or task.endswith("common-multiple-detection")
183
+ ]
184
+ else:
185
+ self.mixup_tasks[k] = [task for task in self.mixup_tasks.keys() if "common-" in task]
186
+
187
+ # print("num annotations", len(self.annotation))
188
+ # print("annotation 0", self.annotation[0])
189
+ # self.annotation = [a for a in self.annotation if "task" in a and "detection" not in a["task"]] # no detection... :(
190
+ self.max_length_seconds = max_length_seconds
191
+ self.cropping = cropping
192
+ self.use_augmentation = use_augmentation
193
+
194
+ ### noise augmentation
195
+ self.rng = random.Random(seed)
196
+ self.rngnp = np.random.default_rng(seed=seed)
197
+ self.noise_dirs = noise_dirs
198
+ self.noise_prob = noise_prob
199
+ self.noise_files = []
200
+ self.low_snr = low_snr
201
+ self.high_snr = high_snr
202
+ self.mask_audio_prob = mask_audio_prob
203
+ if noise_dirs is not None and len(self.noise_dirs) > 0 and self.use_augmentation:
204
+ for noise_dir in noise_dirs:
205
+ noise_from_dir = glob.glob(os.path.join(noise_dir, "*.wav"))
206
+ if len(noise_from_dir) < 3000:
207
+ noise_from_dir = noise_from_dir * 3
208
+ print("noise files from dir", noise_dir, len(noise_from_dir))
209
+ self.noise_files.extend(noise_from_dir)
210
+
211
+ ### mixup augmentation
212
+ self.mixup_prob = mixup_prob
213
+ self.mixup_count = mixup_count
214
+ # ### time scale augmentation
215
+ self.time_scale = time_scale
216
+ self.time_scale_prob = time_scale_prob
217
+ # tasks = set([annotation["task"] if "task" in annotation else "empty" for annotation in self.annotation])
218
+ print(":::all tasks:::", self.tasks.keys())
219
+ print("num examples", len(self.annotation))
220
+
221
+ def __len__(self):
222
+ return len(self.annotation)
223
+
224
+ def collater(self, samples):
225
+ return collater(samples)
226
+
227
+ def load_audio(self, audio_path, shift_allowed: bool, noise_allowed: bool):
228
+ audio, sr = sf.read(audio_path)
229
+ # assert sr == 16000
230
+ if sr != 16000:
231
+ print("other sr!", sr, audio_path)
232
+ if len(audio.shape) == 2: # stereo to mono
233
+ audio = audio.mean(axis=1)
234
+
235
+ ### time scale augmentation
236
+ if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
237
+ # print(f"{index} scaling audio")
238
+ # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] )
239
+ audio = time_scale(torch.tensor(audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
240
+ # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] , suffix='_ts')
241
+
242
+ # Randomly crop a max_length_seconds window if audio is longer than 10 seconds
243
+ if len(audio) > sr * self.max_length_seconds and self.cropping == "random":
244
+ max_start = len(audio) - sr * self.max_length_seconds
245
+ start = random.randint(0, max_start)
246
+ audio = audio[start : start + sr * self.max_length_seconds]
247
+ else: # no random cropping
248
+ audio = audio[: sr * self.max_length_seconds] # Truncate audio to at most max_length_seconds
249
+
250
+ ### noise augmentation
251
+ audio = torch.tensor(audio)
252
+ ### noise augmentation
253
+ if (
254
+ self.use_augmentation
255
+ and self.rng.random() < self.noise_prob
256
+ and len(self.noise_files) > 0
257
+ and noise_allowed
258
+ ):
259
+ # write_example_to_file(os.path.basename(ann["path"]), audio)
260
+ # print(f"{index} adding noise")
261
+ noise_file = self.rng.choice(self.noise_files)
262
+ if not os.path.exists(noise_file):
263
+ print(f"Warning: noise file {noise_file} does not exist")
264
+ else:
265
+ noise_audio, noise_sr = sf.read(noise_file)
266
+ assert noise_sr == 16000
267
+ if len(noise_audio.shape) == 2:
268
+ noise_audio = noise_audio.mean(axis=1)
269
+
270
+ noise_audio = torch.tensor(noise_audio)
271
+
272
+ ### repeat or trim to the audio size
273
+ if len(audio) > len(noise_audio):
274
+ if len(noise_audio) == 0:
275
+ print(
276
+ "----- Warning: Noise audio length is zero. ---------- ",
277
+ noise_file,
278
+ )
279
+ # Option 1: Skip noise augmentation by setting noise_audio to zero
280
+ noise_audio = torch.zeros_like(audio)
281
+ else:
282
+ nrepeats = int(np.maximum(2, np.ceil(len(audio) / len(noise_audio))))
283
+ noise_audio = noise_audio.repeat(nrepeats)
284
+ ### Randomly crop the noise file if it is too long
285
+ if len(noise_audio) > len(audio):
286
+ max_start = len(noise_audio) - len(audio)
287
+ start = random.randint(0, max_start)
288
+ noise_audio = noise_audio[start : start + len(audio)]
289
+
290
+ ### remix with specified snr
291
+ snr = self.rngnp.uniform(self.low_snr, self.high_snr)
292
+ snr = torch.tensor([snr])
293
+ noise_audio = snr_scale(audio, noise_audio, snr)
294
+ audio = audio + noise_audio
295
+
296
+ # write_example_to_file(os.path.basename(audio_path), audio, suffix='_noise')
297
+ if len(audio) > self.max_length_seconds * sr:
298
+ print("long audio", len(audio), len(noise_audio))
299
+ audio = audio[: self.max_length_seconds * sr]
300
+
301
+ # pad all audios to max_len_seconds in _getitem_ to ensure no padding inconsistencies.
302
+ if len(audio) < sr * self.max_length_seconds:
303
+ pad_size = sr * self.max_length_seconds - len(audio)
304
+ audio = torch.nn.functional.pad(audio, (0, pad_size))
305
+
306
+ audio = torch.clamp(audio, -1.0, 1.0)
307
+
308
+ return audio
309
+
310
+ def _mix_labels(self, text, text_to_mix):
311
+ """
312
+ Given two comma-separated label strings (e.g., "gorilla, zebra"),
313
+ combine them without introducing duplicates. If either is "None",
314
+ return the other as-is (unless both are "None").
315
+ """
316
+ # If `text_to_mix` is explicitly "None", just return `text`.
317
+ if text_to_mix == "None":
318
+ return text
319
+
320
+ # If `text` is explicitly "None", just return `text_to_mix`.
321
+ if text == "None":
322
+ return text_to_mix
323
+
324
+ # Split both strings by comma, stripping whitespace
325
+ text_list = [item.strip() for item in text.split(",") if item.strip()]
326
+ text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]
327
+
328
+ # Deduplicate: add only new items from text_to_mix_list
329
+ combined_set = set(text_list)
330
+ for item in text_to_mix_list:
331
+ if item not in combined_set:
332
+ text_list.append(item)
333
+ combined_set.add(item)
334
+
335
+ # If there's nothing left after deduplication, return "None".
336
+ if not text_list:
337
+ return "None"
338
+
339
+ # Rejoin them into a comma-separated string
340
+ return ", ".join(text_list)
341
+
342
+ def _mix_prompts(self, text, text_to_mix, prompt):
343
+ """
344
+ If the prompt is in the form:
345
+ "Which of these, if any, are present in the audio recording? option1, option2, ..."
346
+
347
+ 1. Parse out the question (before '?') and the list of prompt choices (after '?').
348
+ 2. Convert both `text` and `text_to_mix` into lists, checking for items not in the prompt.
349
+ 3. Append any missing answers to the prompt choices.
350
+ 4. Shuffle the choices.
351
+ 5. Reassemble and return the new prompt.
352
+
353
+ If the prompt does not follow the expected structure, it is returned unmodified.
354
+ """
355
+ # Split into two parts: question + choices
356
+ splitted = prompt.split("?")
357
+ if len(splitted) != 2:
358
+ # If we don't have exactly one question mark segment, just return the original prompt
359
+ return prompt
360
+
361
+ question = splitted[0].strip()
362
+ potential_choices_str = splitted[1].strip()
363
+
364
+ # Split the prompt choices
365
+ if not potential_choices_str:
366
+ prompt_choices = []
367
+ else:
368
+ prompt_choices = [c.strip() for c in potential_choices_str.split(",") if c.strip()]
369
+
370
+ # Parse `text`
371
+ text_list = [item.strip() for item in text.split(",") if item.strip()]
372
+
373
+ # Parse `text_to_mix`
374
+ text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]
375
+
376
+ # Add any new items from text_list to the prompt
377
+ for item in text_list:
378
+ if item not in prompt_choices:
379
+ prompt_choices.append(item)
380
+
381
+ # Add any new items from text_to_mix_list to the prompt
382
+ for item in text_to_mix_list:
383
+ if item not in prompt_choices:
384
+ prompt_choices.append(item)
385
+
386
+ # Shuffle consistently with self.rng
387
+ self.rng.shuffle(prompt_choices)
388
+
389
+ # Reassemble
390
+ new_prompt = question + "? " + ", ".join(prompt_choices)
391
+ return new_prompt
392
+
393
+ def _apply_mixup(self, prompt, audio, text, task, filename=None):
394
+ # mixup_applied = False
395
+ if (
396
+ self.use_augmentation and self.rng.random() < self.mixup_prob and task in self.mixup_tasks
397
+ # and text != "None" # Allow complex 'None' examples.
398
+ ):
399
+ # write_example_to_file(os.path.basename(ann["path"]), audio)
400
+ # print(f"{index} mixing up")
401
+ mixup_indices = []
402
+ for pair_task in self.mixup_tasks[task]:
403
+ mixup_indices.extend(self.tasks[pair_task])
404
+ # mixup_indices = mixup_indices.remove(index)
405
+
406
+ if len(mixup_indices) == 0:
407
+ print("No mixup partner found")
408
+ else:
409
+ ### choose n_mixup random partners
410
+ n_mixup = self.rng.randint(1, self.mixup_count)
411
+ mixup_indices = self.rng.sample(mixup_indices, n_mixup)
412
+ # print(f"Mixing up with indices {mixup_indices}")
413
+ for mixup_index in mixup_indices:
414
+ mixup_ann = self.annotation[mixup_index]
415
+ mixup_audio, _ = sf.read(mixup_ann["path"])
416
+ if len(mixup_audio.shape) == 2:
417
+ mixup_audio = mixup_audio.mean(axis=1)
418
+ mixup_audio = mixup_audio[: len(audio)]
419
+ if len(mixup_audio) < len(audio):
420
+ pad_size = len(audio) - len(mixup_audio)
421
+ mixup_audio = np.pad(mixup_audio, (0, pad_size), mode="constant")
422
+ mixup_audio = torch.from_numpy(mixup_audio).float()
423
+ lam = np.clip(self.rngnp.beta(1.0, 1.0), 0.1, 0.8)
424
+
425
+ # Mix the raw_wav
426
+ audio = lam * audio + (1 - lam) * mixup_audio
427
+
428
+ ### Mix the prompts if the labels are given in prompts
429
+ if text in prompt:
430
+ prompt = self._mix_prompts(text, mixup_ann["text"], prompt)
431
+
432
+ ### Mix the labels
433
+ text = self._mix_labels(text, mixup_ann["text"])
434
+
435
+ # mixup_applied = True
436
+
437
+ # DEBUG: If mixup was actually applied, save the final audio
438
+ # if mixup_applied and filename is not None:
439
+ # # Just add a suffix to the original filename to indicate mixup
440
+ # base_filename = os.path.basename(filename)
441
+ # write_example_to_file(
442
+ # base_filename=base_filename,
443
+ # audio=audio,
444
+ # sr=16000,
445
+ # suffix="_mixup",
446
+ # save_dir="mixup_outputs"
447
+ # )
448
+ # print(f"mixup for {filename}::: prompt {prompt} label {text}")
449
+
450
+ return prompt, audio, text
451
+
452
+ def _load_noise(self, shift_allowed: bool):
453
+ noise_file = self.rng.choice(self.noise_files)
454
+ noise_audio, noise_sr = sf.read(noise_file)
455
+ assert noise_sr == 16000, f"Expected noise sample rate 16000, got {noise_sr}"
456
+ if len(noise_audio.shape) == 2:
457
+ noise_audio = noise_audio.mean(axis=1)
458
+
459
+ # Time scale augmentation if applicable
460
+ if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
461
+ noise_audio = time_scale(torch.tensor(noise_audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
462
+
463
+ # Randomly crop or pad to match max_length_seconds
464
+ if len(noise_audio) > self.max_length_seconds * 16000 and self.cropping == "random":
465
+ max_start = len(noise_audio) - self.max_length_seconds * 16000
466
+ start = random.randint(0, max_start)
467
+ noise_audio = noise_audio[start : start + self.max_length_seconds * 16000]
468
+ else:
469
+ noise_audio = noise_audio[: self.max_length_seconds * 16000]
470
+
471
+ # Pad if needed
472
+ if len(noise_audio) < self.max_length_seconds * 16000:
473
+ pad_size = self.max_length_seconds * 16000 - len(noise_audio)
474
+ noise_audio = np.pad(noise_audio, (0, pad_size), mode="constant")
475
+
476
+ noise_audio = torch.tensor(noise_audio).float()
477
+ noise_audio = torch.clamp(noise_audio, -1.0, 1.0)
478
+ return noise_audio
479
+
480
+ def __getitem__(self, index):
481
+ ann = self.annotation[index]
482
+ # print("loading audio::", ann)
483
+ shift_allowed = "pitch" not in ann.get("task", "")
484
+ noise_allowed = (
485
+ "/A/" not in ann.get("path", "")
486
+ and "-qa" not in ann.get("task", "")
487
+ and "icl" not in ann.get("task", "")
488
+ and "caption" not in ann.get("task", "")
489
+ and "animal-instructions" not in ann.get("task", "")
490
+ )
491
+
492
+ task = ann.get("task", "asr")
493
+ text = ann["text"]
494
+ prompt = ann["prompt"]
495
+
496
+ replace_with_noise = (
497
+ self.use_augmentation
498
+ and task.endswith("detection")
499
+ and self.rng.random() < self.mask_audio_prob
500
+ and len(self.noise_files) > 0
501
+ )
502
+
503
+ if replace_with_noise:
504
+ # Replace audio with noise
505
+ audio = self._load_noise(shift_allowed)
506
+ audios = [audio]
507
+ text = "None"
508
+
509
+ else:
510
+ if "path" in ann and ann["path"] is not None:
511
+ audio = self.load_audio(ann["path"], shift_allowed, noise_allowed)
512
+ audios = [audio]
513
+ else:
514
+ audios = [self.load_audio(p, shift_allowed, noise_allowed) for p in ann["files"]]
515
+
516
+ if len(audios) == 1:
517
+ prompt, mixed_audio, text = self._apply_mixup(prompt, audios[0], text, task, filename=ann.get("path"))
518
+ audios = [mixed_audio]
519
+
520
+ return {
521
+ "raw_wav": audios,
522
+ "text": text,
523
+ "task": task,
524
+ "id": ann.get("path") or ";".join(ann["files"]),
525
+ "prompt": prompt,
526
+ "index": index, # track which element for eval output
527
+ "ann": ann, # Include annotation for mixup
528
+ }
529
+
530
+
531
+ if __name__ == "__main__":
532
+ dataset = NatureLMDataset(
533
+ ann_path="/home/ubuntu/foundation-model-storage/foundation-model-data/data/compiled-datasets/v1/s2_eval_valid.jsonl",
534
+ noise_dirs=["resource/audio_demo"],
535
+ max_length_seconds=10,
536
+ use_augmentation=True,
537
+ mixup_prob=1.0, # For demonstration, force mixup if possible
538
+ mixup_count=2, # Up to 2 mixup partners
539
+ mask_audio_prob=0.2,
540
+ seed=42,
541
+ noise_prob=0.5,
542
+ )
543
+
544
+ # Process just a few to see the saved mixups
545
+ for i in range(300):
546
+ sample = dataset[i]
547
+ # print("Final text:", sample["text"])
548
+ # print("Final prompt:", sample["prompt"])
549
+ # print("-" * 40)
550
+ print("Done! Look in 'debug_outputs' folder for saved mixup files.")
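In training code the dataset is typically paired with a standard DataLoader and its own collater, which stacks the per-sample audio lists and builds the padding mask. A sketch with an illustrative annotation path (a JSONL whose rows carry "path", "text", "prompt" and usually "task"):

from torch.utils.data import DataLoader

from NatureLM.dataset import NatureLMDataset

dataset = NatureLMDataset(ann_path="data/train.jsonl", max_length_seconds=10)

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=dataset.collater,  # yields raw_wav, padding_mask, prompt, text, ... per batch
)

batch = next(iter(loader))
print(batch["raw_wav"].shape, batch["padding_mask"].shape, len(batch["prompt"]))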
NatureLM/dist_utils.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ Adapted from salesforce@LAVIS. Below is the original copyright:
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import datetime
10
+ import functools
11
+ import os
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.use_distributed = False
68
+ return
69
+
70
+ args.use_distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ print(
74
+ "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
75
+ flush=True,
76
+ )
77
+ torch.distributed.init_process_group(
78
+ backend=args.dist_backend,
79
+ init_method=args.dist_url,
80
+ world_size=args.world_size,
81
+ rank=args.rank,
82
+ timeout=datetime.timedelta(days=365), # allow auto-downloading and de-compressing
83
+ )
84
+ torch.distributed.barrier()
85
+ setup_for_distributed(args.rank == 0)
86
+
87
+
88
+ def get_dist_info():
89
+ if torch.__version__ < "1.0":
90
+ initialized = dist._initialized
91
+ else:
92
+ initialized = dist.is_initialized()
93
+ if initialized:
94
+ rank = dist.get_rank()
95
+ world_size = dist.get_world_size()
96
+ else: # non-distributed training
97
+ rank = 0
98
+ world_size = 1
99
+ return rank, world_size
100
+
101
+
102
+ def main_process(func):
103
+ @functools.wraps(func)
104
+ def wrapper(*args, **kwargs):
105
+ rank, _ = get_dist_info()
106
+ if rank == 0:
107
+ return func(*args, **kwargs)
108
+
109
+ return wrapper
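A sketch of how these helpers are typically wired together: init_distributed_mode mutates an args-style object carrying the same fields as RunConfig in config.py, and main_process guards a function so it only runs on rank 0. The SimpleNamespace is an illustrative stand-in for that config object:

from types import SimpleNamespace

from NatureLM.dist_utils import init_distributed_mode, main_process

args = SimpleNamespace(dist_backend="nccl", dist_url="env://", use_distributed=False)
init_distributed_mode(args)  # falls back to single-process mode when RANK/WORLD_SIZE are unset

@main_process
def log_metrics(step: int, loss: float) -> None:
    # Runs only on rank 0, so logs are not duplicated across workers.
    print(f"step {step}: loss {loss:.4f}")

log_metrics(0, 1.234)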
NatureLM/infer.py ADDED
@@ -0,0 +1,315 @@
 
1
+ """Run NatureLM-audio over a set of audio file paths or a directory with audio files."""
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import soundfile as sf
9
+ import torch
10
+
11
+ from NatureLM.config import Config
12
+ from NatureLM.models import NatureLM
13
+ from NatureLM.processors import NatureLMAudioProcessor
14
+ from NatureLM.utils import move_to_device
15
+
16
+ _MAX_LENGTH_SECONDS = 10
17
+ _MIN_CHUNK_LENGTH_SECONDS = 0.5
18
+ _SAMPLE_RATE = 16000 # Assuming the model uses a sample rate of 16kHz
19
+ _AUDIO_FILE_EXTENSIONS = [".wav", ".mp3", ".flac", ".ogg"] # Add other audio file formats as needed
20
+ _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
21
+ __this_dir = Path(__file__).parent.parent
22
+ _DEFAULT_CONFIG_PATH = __this_dir / "configs" / "inference.yml"
23
+
24
+
25
+ def load_model_and_config(
26
+ cfg_path: str | Path = _DEFAULT_CONFIG_PATH, device: str = _DEVICE
27
+ ) -> tuple[NatureLM, Config]:
28
+ """Load the NatureLM model and configuration.
29
+ Returns:
30
+ tuple: The loaded model and configuration.
31
+ """
32
+ model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
33
+ model = model.to(device).eval()
34
+ model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
35
+ model.llama_model.generation_config.pad_token_id = model.llama_tokenizer.pad_token_id
36
+
37
+ cfg = Config.from_sources(cfg_path)
38
+ return model, cfg
39
+
40
+
41
+ def output_template(model_output: str, start_time: float, end_time: float) -> str:
42
+ """Format the output of the model."""
43
+ return f"#{start_time:.2f}s - {end_time:.2f}s#: {model_output}\n"
44
+
45
+
46
+ def sliding_window_inference(
47
+ audio: str | Path | np.ndarray,
48
+ query: str,
49
+ processor: NatureLMAudioProcessor,
50
+ model: NatureLM,
51
+ cfg: Config,
52
+ window_length_seconds: float = 10.0,
53
+ hop_length_seconds: float = 10.0,
54
+ input_sr: int = _SAMPLE_RATE,
55
+ device: str = _DEVICE,
56
+ ) -> str:
57
+ """Run inference on a long audio file using sliding window approach.
58
+
59
+ Args:
60
+ audio (str | Path | np.ndarray): Path to the audio file.
61
+ query (str): Query for the model.
62
+ processor (NatureLMAudioProcessor): Audio processor.
63
+ model (NatureLM): NatureLM model.
64
+ cfg (Config): Model configuration.
65
+ window_length_seconds (float): Length of the sliding window in seconds.
66
+ hop_length_seconds (float): Hop length for the sliding window in seconds.
67
+ input_sr (int): Sample rate of the audio file.
68
+
69
+ Returns:
70
+ str: The output of the model.
71
+
72
+ Raises:
73
+ ValueError: If the audio file is too short or if the audio file path is invalid.
74
+ """
75
+ if isinstance(audio, str) or isinstance(audio, Path):
76
+ audio_array, input_sr = sf.read(str(audio))
77
+ elif isinstance(audio, np.ndarray):
78
+ audio_array = audio
79
+ print(f"Using provided sample rate: {input_sr}")
80
+
81
+ audio_array = audio_array.squeeze()
82
+ if audio_array.ndim > 1:
83
+ axis_to_average = int(np.argmin(audio_array.shape))
84
+ audio_array = audio_array.mean(axis=axis_to_average)
85
+ audio_array = audio_array.squeeze()
86
+
87
+ # Do initial check that the audio is long enough
88
+ if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
89
+ raise ValueError(f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds.")
90
+
91
+ start = 0
92
+ stride = int(hop_length_seconds * input_sr)
93
+ window_length = int(window_length_seconds * input_sr)
94
+
95
+ output = ""
96
+ while True:
97
+ chunk = audio_array[start : start + window_length]
98
+ if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
99
+ break
100
+
101
+ # Resamples, pads, truncates and creates torch Tensor
102
+ audio_tensor, prompt_list = processor([chunk], [query], [input_sr])
103
+
104
+ input_to_model = {
105
+ "raw_wav": audio_tensor,
106
+ "prompt": prompt_list[0],
107
+ "audio_chunk_sizes": 1,
108
+ "padding_mask": torch.zeros_like(audio_tensor).to(torch.bool),
109
+ }
110
+ input_to_model = move_to_device(input_to_model, device)
111
+
112
+ # generate
113
+ prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]
114
+
115
+ # Post-process the prediction
116
+ prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
117
+ output += prediction
118
+
119
+ # Move the window
120
+ start += stride
121
+
122
+ if start + window_length > audio_array.shape[-1]:
123
+ break
124
+
125
+ return output
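A quick sketch (not part of the commit) of how the window and hop arithmetic above chunks a recording, assuming the default 10 s window, 10 s hop and 16 kHz audio; each chunk is processed independently and its output is prefixed with its time range by output_template:

import numpy as np

sr = 16000
window, hop = int(10.0 * sr), int(10.0 * sr)
audio = np.zeros(25 * sr)  # a hypothetical 25 s recording

starts = []
start = 0
while True:
    chunk = audio[start : start + window]
    if chunk.shape[-1] < int(0.5 * sr):  # _MIN_CHUNK_LENGTH_SECONDS
        break
    starts.append(start / sr)
    start += hop
    if start + window > audio.shape[-1]:
        break

print(starts)  # [0.0, 10.0]; the trailing 5 s is not emitted as a separate window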
126
+
127
+
128
+ class Pipeline:
129
+ """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""
130
+
131
+ def __init__(self, model: NatureLM = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH):
132
+ self.cfg_path = cfg_path
133
+
134
+ # Load model and config
135
+ if model is not None:
136
+ self.cfg = Config.from_sources(cfg_path)
137
+ self.model = model
138
+ else:
139
+ # Download model from hub
140
+ self.model, self.cfg = load_model_and_config(cfg_path)
141
+
142
+ self.processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)
143
+
144
+ def __call__(
145
+ self,
146
+ audios: list[str | Path | np.ndarray],
147
+ queries: str | list[str],
148
+ window_length_seconds: float = 10.0,
149
+ hop_length_seconds: float = 10.0,
150
+ input_sample_rate: int = _SAMPLE_RATE,
151
+ verbose: bool = False,
152
+ ) -> list[str]:
153
+ """Run inference on a list of audio file paths or a single audio file with a
154
+ single query or a list of queries. If a list of queries is provided, it must be
155
+ in the same order as the audio files; if a single query is provided, it is used
156
+ for all audio files.
157
+
158
+ Args:
159
+ audios (list[str | Path | np.ndarray]): List of audio file paths or a single audio file path or audio array(s)
160
+ queries (str | list[str]): Queries for the model.
161
+ window_length_seconds (float): Length of the sliding window in seconds. Defaults to 10.0.
162
+ hop_length_seconds (float): Hop length for the sliding window in seconds. Defaults to 10.0.
163
+ input_sample_rate (int): Sample rate of the audio. Defaults to 16000, which is the model's sample rate.
164
+ verbose (bool): If True, print the output of the model for each audio file.
165
+ Defaults to False.
166
+
167
+ Returns:
168
+ list[str]: The outputs of the model, one string per audio input.
169
+
170
+ Raises:
171
+ ValueError: If the number of audio files and queries do not match.
172
+
173
+ Example:
174
+ >>> pipeline = Pipeline()
175
+ >>> audios = ["assets/nri-GreenTreeFrogEvergladesNP.mp3"]
176
+ >>> queries = ["Which species is this? Provide the common name."]
177
+ >>> results = pipeline(audios, queries)
178
+ >>> print(results)
179
+ ['#0.00s - 10.00s#: Green Treefrog\n']
180
+ """
181
+ if isinstance(audios, str) or isinstance(audios, Path):
182
+ audios = [audios]
183
+
184
+ if isinstance(queries, str):
185
+ queries = [queries] * len(audios)
186
+
187
+ if len(audios) != len(queries):
188
+ raise ValueError("Number of audio files and queries must match.")
189
+
190
+ # Run inference
191
+ results = []
192
+ for audio, query in zip(audios, queries):
193
+ output = sliding_window_inference(
194
+ audio,
195
+ query,
196
+ self.processor,
197
+ self.model,
198
+ self.cfg,
199
+ window_length_seconds,
200
+ hop_length_seconds,
201
+ input_sr=input_sample_rate,
202
+ )
203
+ results.append(output)
204
+ if verbose:
205
+ print(f"Processed {audio}, model output:\n=======\n{output}\n=======")
206
+ return results
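For example, a single query can be broadcast over several recordings (a usage sketch; the file paths below are hypothetical):

pipeline = Pipeline()
audios = ["recordings/dawn_chorus.wav", "recordings/frog_pond.mp3"]  # hypothetical paths
results = pipeline(audios, "What are the common names for the species in this audio?")
for path, result in zip(audios, results):
    print(path, result, sep="\n")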
207
+
208
+
209
+ def parse_args() -> argparse.Namespace:
210
+ parser = argparse.ArgumentParser("Run NatureLM-audio inference")
211
+ parser.add_argument(
212
+ "-a", "--audio", type=str, required=True, help="Path to an audio file or a directory containing audio files"
213
+ )
214
+ parser.add_argument("-q", "--query", type=str, required=True, help="Query for the model")
215
+ parser.add_argument(
216
+ "--cfg-path",
217
+ type=str,
218
+ default="configs/inference.yml",
219
+ help="Path to the configuration file for the model",
220
+ )
221
+ parser.add_argument("--output_path", type=str, default="inference_output.jsonl", help="Output path for the results")
222
+ parser.add_argument(
223
+ "--window_length_seconds", type=float, default=10.0, help="Length of the sliding window in seconds"
224
+ )
225
+ parser.add_argument(
226
+ "--hop_length_seconds", type=float, default=10.0, help="Hop length for the sliding window in seconds"
227
+ )
228
+ args = parser.parse_args()
229
+
230
+ return args
231
+
232
+
233
+ def main(
234
+ cfg_path: str | Path,
235
+ audio_path: str | Path,
236
+ query: str,
237
+ output_path: str,
238
+ window_length_seconds: float,
239
+ hop_length_seconds: float,
240
+ ) -> None:
241
+ """Main function to run the NatureLM-audio inference script.
242
+ It takes the audio file (or directory) path, query, output path, window length,
243
+ and hop length, runs sliding-window inference on each audio file, and saves the
244
+ results to a JSON Lines file.
245
+
246
+ Args:
247
+ cfg_path (str | Path): Path to the configuration file.
248
+ audio_path (str | Path): Path to the audio file or directory.
249
+ query (str): Query for the model.
250
+ output_path (str): Path to save the output results.
251
+ window_length_seconds (float): Length of the sliding window in seconds.
252
+ hop_length_seconds (float): Hop length for the sliding window in seconds.
253
+
254
+ Raises:
255
+ ValueError: If the audio file path is invalid or if the query is empty.
256
+ ValueError: If no audio files are found.
257
+ ValueError: If the audio file extension is not supported.
258
+ """
259
+
260
+ # Prepare sample
261
+ audio_path = Path(audio_path)
262
+ if audio_path.is_dir():
263
+ audio_paths = []
264
+ print(f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}")
265
+ for ext in _AUDIO_FILE_EXTENSIONS:
266
+ audio_paths.extend(list(audio_path.rglob(f"*{ext}")))
267
+
268
+ print(f"Found {len(audio_paths)} audio files in {str(audio_path)}")
269
+ else:
270
+ # check that the extension is valid
271
+ if not any(audio_path.suffix == ext for ext in _AUDIO_FILE_EXTENSIONS):
272
+ raise ValueError(
273
+ f"Invalid audio file extension. Supported extensions are: {', '.join(_AUDIO_FILE_EXTENSIONS)}"
274
+ )
275
+ audio_paths = [audio_path]
276
+
277
+ # check that query is not empty
278
+ if not query:
279
+ raise ValueError("Query cannot be empty")
280
+ if not audio_paths:
281
+ raise ValueError("No audio files found. Please check the path or file extensions.")
282
+
283
+ # Load model and config
284
+ model, cfg = load_model_and_config(cfg_path)
285
+
286
+ # Load audio processor
287
+ processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)
288
+
289
+ # Run inference
290
+ results = {"audio_path": [], "output": []}
291
+ for path in audio_paths:
292
+ output = sliding_window_inference(path, query, processor, model, cfg, window_length_seconds, hop_length_seconds)
293
+ results["audio_path"].append(str(path))
294
+ results["output"].append(output)
295
+ print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")
296
+
297
+ # Save results as JSON Lines (one record per audio file)
298
+ output_path = Path(output_path)
299
+ output_path.parent.mkdir(parents=True, exist_ok=True)
300
+
301
+ df = pd.DataFrame(results)
302
+ df.to_json(output_path, orient="records", lines=True)
303
+ print(f"Results saved to {output_path}")
304
+
305
+
306
+ if __name__ == "__main__":
307
+ args = parse_args()
308
+ main(
309
+ cfg_path=args.cfg_path,
310
+ audio_path=args.audio,
311
+ query=args.query,
312
+ output_path=args.output_path,
313
+ window_length_seconds=args.window_length_seconds,
314
+ hop_length_seconds=args.hop_length_seconds,
315
+ )
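The same entry point can also be driven programmatically; a minimal sketch using one of the bundled assets and the defaults exposed by the argument parser:

main(
    cfg_path="configs/inference.yml",
    audio_path="assets/nri-GreenTreeFrogEvergladesNP.mp3",
    query="Which species is this? Provide the common name.",
    output_path="inference_output.jsonl",
    window_length_seconds=10.0,
    hop_length_seconds=10.0,
)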
NatureLM/logger.py ADDED
@@ -0,0 +1,190 @@
1
+ import datetime
2
+ import logging
3
+ import time
4
+ from collections import defaultdict, deque
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ import wandb
9
+
10
+ from NatureLM.dist_utils import is_dist_avail_and_initialized, is_main_process
11
+
12
+
13
+ class SmoothedValue(object):
14
+ """Track a series of values and provide access to smoothed values over a
15
+ window or the global series average.
16
+ """
17
+
18
+ def __init__(self, window_size=20, fmt=None):
19
+ if fmt is None:
20
+ fmt = "{median:.4f} ({global_avg:.4f})"
21
+ self.deque = deque(maxlen=window_size)
22
+ self.total = 0.0
23
+ self.count = 0
24
+ self.fmt = fmt
25
+
26
+ def update(self, value, n=1):
27
+ self.deque.append(value)
28
+ self.count += n
29
+ self.total += value * n
30
+
31
+ def synchronize_between_processes(self):
32
+ """
33
+ Warning: does not synchronize the deque!
34
+ """
35
+ if not is_dist_avail_and_initialized():
36
+ return
37
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
38
+ dist.barrier()
39
+ dist.all_reduce(t)
40
+ t = t.tolist()
41
+ self.count = int(t[0])
42
+ self.total = t[1]
43
+
44
+ @property
45
+ def median(self):
46
+ d = torch.tensor(list(self.deque))
47
+ return d.median().item()
48
+
49
+ @property
50
+ def avg(self):
51
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
52
+ return d.mean().item()
53
+
54
+ @property
55
+ def global_avg(self):
56
+ return self.total / self.count
57
+
58
+ @property
59
+ def max(self):
60
+ return max(self.deque)
61
+
62
+ @property
63
+ def value(self):
64
+ return self.deque[-1]
65
+
66
+ def __str__(self):
67
+ return self.fmt.format(
68
+ median=self.median,
69
+ avg=self.avg,
70
+ global_avg=self.global_avg,
71
+ max=self.max,
72
+ value=self.value,
73
+ )
74
+
75
+
76
+ class MetricLogger(object):
77
+ def __init__(self, delimiter="\t"):
78
+ self.meters = defaultdict(SmoothedValue)
79
+ self.delimiter = delimiter
80
+
81
+ def update(self, **kwargs):
82
+ for k, v in kwargs.items():
83
+ if isinstance(v, torch.Tensor):
84
+ v = v.item()
85
+ assert isinstance(v, (float, int))
86
+ self.meters[k].update(v)
87
+
88
+ def __getattr__(self, attr):
89
+ if attr in self.meters:
90
+ return self.meters[attr]
91
+ if attr in self.__dict__:
92
+ return self.__dict__[attr]
93
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
94
+
95
+ def __str__(self):
96
+ loss_str = []
97
+ for name, meter in self.meters.items():
98
+ loss_str.append("{}: {}".format(name, str(meter)))
99
+ return self.delimiter.join(loss_str)
100
+
101
+ def global_avg(self):
102
+ loss_str = []
103
+ for name, meter in self.meters.items():
104
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
105
+ return self.delimiter.join(loss_str)
106
+
107
+ def synchronize_between_processes(self):
108
+ for meter in self.meters.values():
109
+ meter.synchronize_between_processes()
110
+
111
+ def add_meter(self, name, meter):
112
+ self.meters[name] = meter
113
+
114
+ def log_every(self, iterable, print_freq, header=None, logger=None, start_step=None):
115
+ i = 0
116
+ if not header:
117
+ header = ""
118
+ start_time = time.time()
119
+ end = time.time()
120
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
121
+ data_time = SmoothedValue(fmt="{avg:.4f}")
122
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
123
+ log_msg = [
124
+ header,
125
+ "[{0" + space_fmt + "}/{1}]",
126
+ "eta: {eta}",
127
+ "{meters}",
128
+ "time: {time}",
129
+ "data: {data}",
130
+ ]
131
+ if torch.cuda.is_available():
132
+ log_msg.append("max mem: {memory:.0f}")
133
+ log_msg = self.delimiter.join(log_msg)
134
+ MB = 1024.0 * 1024.0
135
+ for obj in iterable:
136
+ data_time.update(time.time() - end)
137
+ yield obj
138
+ iter_time.update(time.time() - end)
139
+ if i % print_freq == 0 or i == len(iterable) - 1:
140
+ if is_main_process():
141
+ if logger is not None:
142
+ assert start_step is not None, "start_step is needed to compute global_step!"
143
+ for name, meter in self.meters.items():
144
+ logger.add_scalar("{}".format(name), meter.global_avg, global_step=start_step + i)
145
+ # Log to wandb
146
+ wandb.log({name: meter.global_avg for name, meter in self.meters.items()}, step=start_step + i)
147
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
+ if torch.cuda.is_available():
150
+ print(
151
+ log_msg.format(
152
+ i,
153
+ len(iterable),
154
+ eta=eta_string,
155
+ meters=str(self),
156
+ time=str(iter_time),
157
+ data=str(data_time),
158
+ memory=torch.cuda.max_memory_allocated() / MB,
159
+ )
160
+ )
161
+ else:
162
+ print(
163
+ log_msg.format(
164
+ i,
165
+ len(iterable),
166
+ eta=eta_string,
167
+ meters=str(self),
168
+ time=str(iter_time),
169
+ data=str(data_time),
170
+ )
171
+ )
172
+ i += 1
173
+ end = time.time()
174
+ total_time = time.time() - start_time
175
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
176
+ print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
177
+
178
+
179
+ class AttrDict(dict):
180
+ def __init__(self, *args, **kwargs):
181
+ super(AttrDict, self).__init__(*args, **kwargs)
182
+ self.__dict__ = self
183
+
184
+
185
+ def setup_logger():
186
+ logging.basicConfig(
187
+ level=logging.INFO if is_main_process() else logging.WARN,
188
+ format="%(asctime)s [%(levelname)s] %(message)s",
189
+ handlers=[logging.StreamHandler()],
190
+ )
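A small, self-contained sketch of how these utilities are typically wired into a training loop (the loop body below is a stand-in, not code from this commit):

import torch

from NatureLM.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

batches = range(200)  # stand-in for a dataloader
for step in metric_logger.log_every(batches, print_freq=50, header="Train:"):
    loss = torch.rand(1)  # stand-in for a real training loss
    metric_logger.update(loss=loss, lr=1e-4)

print("Averaged stats:", metric_logger.global_avg())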
NatureLM/models/NatureLM.py ADDED
@@ -0,0 +1,666 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ from pathlib import Path
18
+ from typing import Literal, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from huggingface_hub import PyTorchModelHubMixin
24
+ from peft import LoraConfig, TaskType, get_peft_model
25
+ from torch.nn import CrossEntropyLoss
26
+ from torch.nn.utils.rnn import pad_sequence
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
28
+
29
+ from NatureLM.checkpoint_utils import save_model_checkpoint
30
+ from NatureLM.config import BeatsConfig, ModelConfig, save_config_as_yaml
31
+ from NatureLM.utils import universal_torch_load
32
+
33
+ from .beats.BEATs import BEATs, BEATsConfig
34
+ from .Qformer import BertConfig, BertLMHeadModel
35
+ from .utils import StoppingCriteriaSub
36
+
37
+ torch.backends.cuda.matmul.allow_tf32 = True
38
+
39
+ auth_token = os.getenv('llama')
40
+
41
+ class NatureLM(nn.Module, PyTorchModelHubMixin):
42
+ def __init__(
43
+ self,
44
+ *,
45
+ llama_path: Path,
46
+ beats_path: Path | os.PathLike | None = None,
47
+ beats_cfg: BeatsConfig,
48
+ freeze_beats: bool = True,
49
+ use_audio_Qformer: bool = True,
50
+ max_pooling: bool = False,
51
+ num_audio_query_token: int = 1,
52
+ freeze_audio_QFormer: bool = False,
53
+ window_level_Qformer: bool = True,
54
+ second_per_window: float = 0.333333,
55
+ second_stride: float = 0.333333,
56
+ downsample_factor: int = 4,
57
+ audio_llama_proj_model: Path | os.PathLike | None = None,
58
+ freeze_audio_llama_proj: bool = False,
59
+ lora: bool = True,
60
+ lora_rank: int = 8,
61
+ lora_alpha: int = 32,
62
+ lora_dropout: float = 0.1,
63
+ flash_attn: Literal["eager", "flash_attention_2"] = "eager",
64
+ prompt_template: str = "",
65
+ max_txt_len: int = 128,
66
+ end_sym: str = "</s>",
67
+ device: str = "cuda",
68
+ ):
69
+ super().__init__()
70
+
71
+ self.beats_path = beats_path
72
+ self.beats_cfg = beats_cfg
73
+ self.use_audio_Qformer = use_audio_Qformer
74
+ self.max_pooling = max_pooling
75
+ self.window_level_Qformer = window_level_Qformer
76
+ self.second_per_window = second_per_window
77
+ self.second_stride = second_stride
78
+ self.downsample_factor = downsample_factor
79
+ self.lora = lora
80
+ self.max_txt_len = max_txt_len
81
+ self.end_sym = end_sym
82
+ self.prompt_template = prompt_template
83
+ self.flash_attn = flash_attn
84
+
85
+ logging.info(f"Llama path: {llama_path}")
86
+ logging.info("Loading Llama Tokenizer")
87
+ self.llama_tokenizer = AutoTokenizer.from_pretrained(llama_path, use_fast=False, use_auth_token=auth_token)
88
+ self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
89
+ self.llama_tokenizer.padding_side = "right"
90
+
91
+ logging.info("Loading Llama Model")
92
+ if device == "cpu":
93
+ self.llama_model = AutoModelForCausalLM.from_pretrained(
94
+ llama_path,
95
+ torch_dtype=torch.float32,
96
+ attn_implementation="eager",
97
+ device_map="cpu",
98
+ use_auth_token=auth_token
99
+ )
100
+ # An issue with tiny-llama is that pad_token_id was set to -1, but
101
+ # model.save_pretrained checks generation configs and does not allow -1 as
102
+ # pad_token_id
103
+ self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id
104
+ else:
105
+ self.llama_model = AutoModelForCausalLM.from_pretrained(
106
+ llama_path,
107
+ torch_dtype=torch.bfloat16,
108
+ attn_implementation=flash_attn,
109
+ use_auth_token=auth_token
110
+ )
111
+
112
+ self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
113
+ if self.lora:
114
+ for param in self.llama_model.parameters():
115
+ param.requires_grad = False
116
+ logging.info("Loading LLaMA Done")
117
+ self.llama_embed_tokens = self.llama_model.model.embed_tokens
118
+
119
+ if self.lora:
120
+ logging.info("Setting up LoRA for llama model")
121
+ self.peft_config = LoraConfig(
122
+ task_type=TaskType.CAUSAL_LM,
123
+ inference_mode=False,
124
+ r=lora_rank,
125
+ lora_alpha=lora_alpha,
126
+ lora_dropout=lora_dropout,
127
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
128
+ )
129
+ self.llama_model = get_peft_model(self.llama_model, self.peft_config)
130
+ self.llama_embed_tokens = self.llama_model.model.model.embed_tokens
131
+ self.llama_model.print_trainable_parameters()
132
+ logging.info("LoRA Training")
133
+
134
+ logging.info("Loading BEATs Model")
135
+ self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))
136
+
137
+ if self.beats_path:
138
+ beats_ckpt = universal_torch_load(self.beats_path, cache_mode="none", map_location="cpu")
139
+ self.beats.load_state_dict(beats_ckpt["model"])
140
+
141
+ self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
142
+ if freeze_beats:
143
+ for param in self.beats.parameters():
144
+ param.requires_grad = False
145
+ self.beats.eval()
146
+ logging.info("freeze BEATs")
147
+
148
+ if self.use_audio_Qformer:
149
+ self.audio_Qformer, self.audio_query_tokens = self.init_audio_Qformer(
150
+ num_query_token=num_audio_query_token,
151
+ audio_width=self.beats.cfg.encoder_embed_dim,
152
+ )
153
+
154
+ self.audio_Qformer.bert.embeddings.word_embeddings = None
155
+ self.audio_Qformer.bert.embeddings.position_embeddings = None
156
+ for layer in self.audio_Qformer.bert.encoder.layer:
157
+ layer.output = None
158
+ layer.intermediate = None
159
+ self.audio_Qformer.cls = None
160
+ if freeze_audio_QFormer:
161
+ for param in self.audio_Qformer.parameters():
162
+ param.requires_grad = False
163
+ self.audio_Qformer.eval()
164
+ self.audio_query_tokens.requires_grad = False
165
+ logging.info("freeze audio QFormer")
166
+
167
+ logging.info("Loading audio LLAMA proj")
168
+ self.audio_llama_proj = nn.Linear(
169
+ self.audio_Qformer.config.hidden_size,
170
+ self.llama_model.config.hidden_size,
171
+ )
172
+ if audio_llama_proj_model:
173
+ logging.info(f"Loading audio LLAMA proj from {audio_llama_proj_model}")
174
+ # audio_llama_proj_weight = torch.load(audio_llama_proj_model, map_location="cpu")
175
+ audio_llama_proj_weight = universal_torch_load(
176
+ audio_llama_proj_model, cache_mode="use", map_location="cpu"
177
+ )
178
+ self.load_state_dict(audio_llama_proj_weight["model"], strict=False)
179
+
180
+ if freeze_audio_llama_proj:
181
+ for param in self.audio_llama_proj.parameters():
182
+ param.requires_grad = False
183
+ self.audio_llama_proj.eval()
184
+ logging.info("freeze audio LLAMA proj")
185
+
186
+ elif self.max_pooling:
187
+ # Only BEATs is initialised as the audio encoder in this class, so project
+ # its embedding dimension directly to the LLM hidden size.
+ hidden_size = self.beats.cfg.encoder_embed_dim
196
+ self.audio_llama_proj = nn.Linear(
197
+ hidden_size, self.llama_model.config.hidden_size
198
+ ) # Single embedding, just project to LLM.
199
+
200
+ elif getattr(self, "htsat", False):  # HTSAT is not initialised in this class; guard so the else branch stays reachable
201
+ self.audio_llama_proj = nn.Linear(
202
+ 512, self.llama_model.config.hidden_size
203
+ ) # Single embedding, just project to LLM.
204
+
205
+ else:
206
+ # feel free to add other aligners here
207
+ raise NotImplementedError("Have to use audio qformer")
208
+
209
+ self.config: ModelConfig = None # set this in from_config
210
+
211
+ @classmethod
212
+ def from_config(cls, config: ModelConfig):
213
+ model = cls(
214
+ llama_path=config.llama_path,
215
+ beats_path=config.beats_path,
216
+ freeze_beats=config.freeze_beats,
217
+ use_audio_Qformer=config.use_audio_Qformer,
218
+ max_pooling=config.max_pooling,
219
+ num_audio_query_token=config.num_audio_query_token,
220
+ freeze_audio_QFormer=config.freeze_audio_QFormer,
221
+ window_level_Qformer=config.window_level_Qformer,
222
+ second_per_window=config.second_per_window,
223
+ second_stride=config.second_stride,
224
+ downsample_factor=config.downsample_factor,
225
+ audio_llama_proj_model=config.audio_llama_proj_model,
226
+ freeze_audio_llama_proj=config.freeze_audio_llama_proj,
227
+ lora=config.lora,
228
+ lora_rank=config.lora_rank,
229
+ lora_alpha=config.lora_alpha,
230
+ lora_dropout=config.lora_dropout,
231
+ prompt_template=config.prompt_template,
232
+ max_txt_len=config.max_txt_len,
233
+ end_sym=config.end_sym,
234
+ flash_attn=config.flash_attn,
235
+ device=config.device,
236
+ )
237
+ model.config = config
238
+ ckpt_path = config.ckpt
239
+ if ckpt_path:
240
+ logging.info(f"⏳ Load NatureLM ckpt from: {ckpt_path}")
241
+ ckpt = universal_torch_load(ckpt_path, cache_mode="use", map_location="cpu")
242
+ model.load_state_dict(ckpt["model"], strict=False)
243
+ logging.info("✅ Finished loading from ckpt")
244
+
245
+ return model
246
+
247
+ def _save_to_local(
248
+ self,
249
+ output_dir: Union[str, os.PathLike],
250
+ use_distributed: bool = False,
251
+ drop_untrained_params: bool = False,
252
+ ) -> None:
253
+ output_dir = Path(output_dir)
254
+ output_dir.mkdir(parents=True, exist_ok=True)
255
+
256
+ # Save the config
257
+ config_path = output_dir / "model_config.yaml"
258
+ save_config_as_yaml(self.config, config_path)
259
+
260
+ # Save the model
261
+ model_path = output_dir / "model.pt"
262
+ save_model_checkpoint(
263
+ self,
264
+ model_path,
265
+ drop_untrained_params=drop_untrained_params,
266
+ use_distributed=use_distributed,
267
+ )
268
+
269
+ # Save the tokenizer and llama model
270
+ tokenizer_path = output_dir / "llama"
271
+ self.llama_tokenizer.save_pretrained(tokenizer_path)
272
+ self.llama_model.save_pretrained(tokenizer_path)
273
+
274
+ # Save the audio model
275
+ if self.beats_path:
276
+ beats_path = output_dir / "beats.pt"
277
+ save_model_checkpoint(
278
+ self.beats,
279
+ beats_path,
280
+ drop_untrained_params=drop_untrained_params,
281
+ cfg=self.beats_cfg,
282
+ )
283
+
284
+ # Save the audio projection
285
+ audio_llama_proj_path = output_dir / "audio_llama_proj.pt"
286
+ save_model_checkpoint(
287
+ self.audio_llama_proj,
288
+ audio_llama_proj_path,
289
+ drop_untrained_params=drop_untrained_params,
290
+ )
291
+
292
+ @staticmethod
293
+ def init_audio_Qformer(num_query_token, audio_width, num_hidden_layers=2):
294
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
295
+ encoder_config.num_hidden_layers = num_hidden_layers
296
+ encoder_config.encoder_width = audio_width
297
+ # insert cross-attention layer every other block
298
+ encoder_config.add_cross_attention = True
299
+ encoder_config.cross_attention_freq = 1
300
+ encoder_config.query_length = num_query_token
301
+ Qformer = BertLMHeadModel(config=encoder_config)
302
+ query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
303
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
304
+ return Qformer, query_tokens
305
+
306
+ @property
307
+ def device(self):
308
+ return next(self.parameters()).device
309
+
310
+ def _encode_auditory_feature(self, audio_embeds, audio_pad_mask):
311
+ if self.max_pooling:
312
+ # Max Pooling logic to reduce sequence length
313
+
314
+ # Apply 1D Max Pooling along the time dimension
315
+ audio_embeds = F.max_pool1d(
316
+ audio_embeds.transpose(1, 2),
317
+ kernel_size=self.downsample_factor,
318
+ stride=self.downsample_factor,
319
+ ).transpose(1, 2)
320
+ audio_embeds = self.audio_llama_proj(audio_embeds)
321
+
322
+ # print("audio pad mask is", audio_pad_mask)
323
+ audio_atts = ~audio_pad_mask
324
+ # Adjust the padding mask using max pooling
325
+ audio_atts = F.max_pool1d(
326
+ audio_atts.unsqueeze(1).float(),
327
+ kernel_size=self.downsample_factor,
328
+ stride=self.downsample_factor,
329
+ ).squeeze(1)
330
+ audio_atts = audio_atts > 0
331
+ # print(f"audio pad mask shape after pooling: {audio_atts.shape}")
332
+ # print("audio pad mask post", audio_atts)
333
+
334
+ elif self.use_audio_Qformer:
335
+ # Q-Former logic
336
+ audio_embeds = self.ln_audio(audio_embeds)
337
+
338
+ # Generate attention mask
339
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
340
+
341
+ if self.window_level_Qformer:
342
+ B, T, C = audio_embeds.shape # batch, T, Channels
343
+ kernel = round(1500 * self.second_per_window / 30.0) # 160 ms patches; calculate kernel size
344
+ stride = round(1500 * self.second_stride / 30.0) # Calculate stride size
345
+ kernel = (1, kernel)
346
+ stride = (1, stride)
347
+
348
+ # Transpose and unfold audio embeddings to create overlapping windows
349
+ audio_embeds_tr = audio_embeds.transpose(1, 2).unsqueeze(2)
350
+ audio_embeds_overlap = F.unfold(
351
+ audio_embeds_tr,
352
+ kernel_size=kernel,
353
+ dilation=1,
354
+ padding=0,
355
+ stride=stride,
356
+ )
357
+ _, _, L = audio_embeds_overlap.shape
358
+ audio_embeds_overlap = audio_embeds_overlap.view(B, -1, kernel[1], L)
359
+ audio_embeds_overlap = torch.permute(
360
+ audio_embeds_overlap, [0, 3, 2, 1]
361
+ ) # (B, num_windows, kernel_size, C)
362
+ audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
363
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
364
+
365
+ # Q-Former mechanism
366
+ query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
367
+ query_output = self.audio_Qformer.bert(
368
+ query_embeds=query_tokens,
369
+ encoder_hidden_states=audio_embeds,
370
+ encoder_attention_mask=audio_atts,
371
+ return_dict=True,
372
+ )
373
+
374
+ audio_embeds = self.audio_llama_proj(query_output.last_hidden_state)
375
+
376
+ if self.window_level_Qformer:
377
+ audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()
378
+
379
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
380
+
381
+ elif getattr(self, "htsat", False):  # HTSAT is not initialised in this class; guard so the else branch stays reachable
382
+ # HTSAT processing
383
+ audio_embeds = self.ln_audio(audio_embeds)
384
+ audio_embeds = self.audio_llama_proj(audio_embeds).reshape(-1, 30, self.llama_model.config.hidden_size)
385
+ audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
386
+
387
+ else:
388
+ raise NotImplementedError("no audio qformer or max pooling")
389
+
390
+ return audio_embeds, audio_atts
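For orientation, a sketch of the window-level Q-Former geometry above, using the 1500-frames-per-30-s figure from the code itself and the constructor defaults (second_per_window = second_stride = 0.333333, num_audio_query_token = 1):

import math

frames_per_second = 1500 / 30.0         # ~50 BEATs frames per second of audio
kernel = round(1500 * 0.333333 / 30.0)  # 17 frames (~0.33 s) per window
stride = kernel                         # non-overlapping windows
T = int(10 * frames_per_second)         # ~500 frames for a 10 s clip
num_windows = math.floor((T - kernel) / stride) + 1
print(kernel, num_windows)              # 17 29 -> a 10 s clip becomes ~29 audio tokens (one query token per window)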
391
+
392
+ def encode_audio(self, raw_wav, audio_padding_mask=None):
393
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
394
+ audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
395
+ return self._encode_auditory_feature(audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask)
396
+
397
+ def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
398
+ """Merge audio embeddings with embeddings of the tokens in the prompt.
399
+
400
+ Args:
401
+ audio_embeds (list): List of tensors of audio embeddings.
402
+ audio_atts (list): List of tensors of audio padding masks.
403
+ prompt (list): List of strings with the prompt for each sample. Each prompt
404
+ should contain the placeholder(s) "<AudioHere>" to indicate where the
405
+ audio embeddings should be inserted.
406
+
407
+ Returns:
408
+ tuple: A tuple containing the wrapped audio embeddings and padding masks.
409
+ """
410
+
411
+ def interleave_lists(longer: list, shorter: list) -> list:
412
+ """Interleave two lists where the first list is one element longer.
413
+
414
+ Args:
415
+ longer (list): The first list with length n.
416
+ shorter (list): The second list with length n-1.
417
+
418
+ Returns:
419
+ list: A new list with elements interleaved from longer and shorter.
420
+
421
+ Example:
422
+ >>> interleave_lists(['a1', 'a2', 'a3'], ['b1', 'b2'])
423
+ ['a1', 'b1', 'a2', 'b2', 'a3']
424
+ """
425
+ interleaved_list = []
426
+ for i in range(len(shorter)):
427
+ interleaved_list.append(longer[i])
428
+ interleaved_list.append(shorter[i])
429
+ interleaved_list.append(longer[-1]) # last element is from longer
430
+ return interleaved_list
431
+
432
+ device = audio_embeds[0].device
433
+
434
+ wrapped_embeds_list = []
435
+ wrapped_atts_list = []
436
+ batch_size = len(prompt)
437
+ for i in range(batch_size):
438
+ prompt_parts = prompt[i].split("<AudioHere>")
439
+ wrapped_embeds = []
440
+ wrapped_atts = []
441
+
442
+ for part in prompt_parts:
443
+ tokens = self.llama_tokenizer(part, return_tensors="pt", add_special_tokens=False).to(device)
444
+ part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
445
+ part_atts = tokens.attention_mask.squeeze(0)
446
+ wrapped_embeds.append(part_embeds)
447
+ wrapped_atts.append(part_atts)
448
+
449
+ # Process each element in the batch to remove padding
450
+ if self.max_pooling:
451
+ audio_embeds[i] = list(audio_embeds[i].unbind(0))
452
+ audio_atts[i] = list(audio_atts[i].unbind(0))
453
+ for j in range(len(audio_embeds[i])):
454
+ audio_embeds[i][j] = audio_embeds[i][j][audio_atts[i][j]]
455
+ audio_atts[i][j] = audio_atts[i][j][audio_atts[i][j]]
456
+
457
+ # Interleave wrapped_embeds and audio_embeds using interleave_lists
458
+ wrapped_embeds = interleave_lists(wrapped_embeds, audio_embeds[i])
459
+ wrapped_atts = interleave_lists(wrapped_atts, audio_atts[i])
460
+
461
+ wrapped_embeds = torch.cat(wrapped_embeds, dim=0)
462
+ wrapped_atts = torch.cat(wrapped_atts, dim=0)
463
+ wrapped_embeds_list.append(wrapped_embeds)
464
+ wrapped_atts_list.append(wrapped_atts)
465
+
466
+ wrapped_embeds = pad_sequence(wrapped_embeds_list, batch_first=True)
467
+ wrapped_atts = pad_sequence(wrapped_atts_list, batch_first=True)
468
+ return wrapped_embeds, wrapped_atts
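To make the interleaving concrete, a small sketch of how a prompt is split around the placeholder before the audio embeddings are slotted back in (the prompt text here is illustrative):

prompt = "Audio: <AudioHere>\nWhich species is this? Provide the common name."
parts = prompt.split("<AudioHere>")
print(parts)  # ['Audio: ', '\nWhich species is this? Provide the common name.']
# prompt_wrap embeds each text part with the LLM token embeddings, places the
# sample's audio embeddings between consecutive parts, then pads across the batch.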
469
+
470
+ def forward(self, samples, verbose=True):
471
+ # Prepare prompts
472
+ prompt = samples["prompt"]
473
+ prompt = [self.prompt_template.format(p) for p in prompt]
474
+
475
+ # Encode the raw audio with the audio encoder
476
+ raw_wav = samples.get("raw_wav", None)
477
+ audio_padding_mask = samples.get("padding_mask", None)
478
+
479
+ audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask)
480
+ audio_chunk_sizes = samples["audio_chunk_sizes"]
481
+ split_audio_embeds = list(torch.split(audio_embeds, audio_chunk_sizes, dim=0))
482
+ split_audio_atts = list(torch.split(audio_atts, audio_chunk_sizes, dim=0))
483
+
484
+ # Wrap audio_embeds with prompts
485
+ audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompt)
486
+
487
+ # Prepare inputs for LLM
488
+ text = [t + self.end_sym for t in samples["text"]]
489
+ to_regress_tokens = self.llama_tokenizer(
490
+ text,
491
+ return_tensors="pt",
492
+ padding="longest",
493
+ truncation=True,
494
+ max_length=self.max_txt_len,
495
+ add_special_tokens=False,
496
+ ).to(audio_embeds.device)
497
+
498
+ to_regress_embeds = self.llama_embed_tokens(to_regress_tokens.input_ids)
499
+
500
+ # Prepare targets
501
+ targets = to_regress_tokens.input_ids.masked_fill(
502
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
503
+ )
504
+
505
+ batch_size = audio_embeds.size(0)
506
+
507
+ # BOS token embeddings
508
+ bos_token_id = self.llama_tokenizer.bos_token_id
509
+ bos = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device)
510
+ bos_embeds = self.llama_embed_tokens(bos)
511
+
512
+ # Prepare lists to collect per-sample embeddings, attention masks, and targets
513
+ inputs_embeds_list = []
514
+ attention_mask_list = []
515
+ targets_list = []
516
+
517
+ for i in range(batch_size):
518
+ # Extract non-padded audio embeddings and attention mask
519
+ audio_embed = audio_embeds[i][audio_atts[i].bool()]
520
+ audio_att = audio_atts[i][audio_atts[i].bool()]
521
+
522
+ # Extract non-padded text embeddings and attention mask
523
+ text_embed = to_regress_embeds[i][to_regress_tokens.attention_mask[i].bool()]
524
+ text_att = to_regress_tokens.attention_mask[i][to_regress_tokens.attention_mask[i].bool()]
525
+
526
+ # Extract corresponding targets for the text tokens
527
+ target = targets[i][to_regress_tokens.attention_mask[i].bool()]
528
+
529
+ # Concatenate embeddings: BOS token, audio embeddings, text embeddings
530
+ input_embeds = torch.cat([bos_embeds[i], audio_embed, text_embed], dim=0)
531
+
532
+ # Concatenate attention masks: BOS token mask, audio attention mask, text attention mask
533
+ att_mask = torch.cat(
534
+ [
535
+ torch.ones(1, device=audio_embeds.device, dtype=audio_att.dtype),
536
+ audio_att,
537
+ text_att,
538
+ ],
539
+ dim=0,
540
+ )
541
+
542
+ # Create targets: Ignore index (-100) for BOS and audio tokens, actual targets for text tokens
543
+ ignore_targets = torch.full(
544
+ (1 + audio_embed.size(0),),
545
+ -100,
546
+ device=audio_embeds.device,
547
+ dtype=targets.dtype,
548
+ )
549
+ sample_targets = torch.cat([ignore_targets, target], dim=0)
550
+
551
+ # Append to lists
552
+ inputs_embeds_list.append(input_embeds)
553
+ attention_mask_list.append(att_mask)
554
+ targets_list.append(sample_targets)
555
+
556
+ # Pad sequences to the maximum length in the batch
557
+ inputs_embeds_padded = pad_sequence(inputs_embeds_list, batch_first=True)
558
+ attention_mask_padded = pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
559
+ targets_padded = pad_sequence(targets_list, batch_first=True, padding_value=-100)
560
+
561
+ # Now use the padded embeddings, attention masks, and targets in the model
562
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
563
+ outputs = self.llama_model(
564
+ inputs_embeds=inputs_embeds_padded,
565
+ attention_mask=attention_mask_padded,
566
+ return_dict=True,
567
+ labels=targets_padded,
568
+ )
569
+ loss = outputs.loss # Original batch loss
570
+
571
+ # Compute per-example loss
572
+ nvocab = self.llama_model.config.vocab_size
573
+ logits = outputs.logits
574
+
575
+ shift_logits = logits[..., :-1, :].contiguous()
576
+ shift_labels = targets_padded[..., 1:].contiguous()
577
+
578
+ # Compute loss per token
579
+ loss_fct_per_example = CrossEntropyLoss(reduction="none")
580
+ loss_per_token = loss_fct_per_example(
581
+ shift_logits.view(-1, nvocab), # Flatten to [batch_size * (seq_len-1), vocab_size]
582
+ shift_labels.view(-1), # Flatten to [batch_size * (seq_len-1)]
583
+ )
584
+ loss_per_token = loss_per_token.view(shift_labels.size()) # Reshape back to [batch_size, seq_len-1]
585
+
586
+ # Create mask
587
+ mask = shift_labels != -100 # [batch_size, seq_len-1]
588
+
589
+ # Apply mask to loss_per_token
590
+ loss_per_token = loss_per_token * mask.float()
591
+
592
+ # Compute per-example loss
593
+ loss_per_example = loss_per_token.sum(dim=1) / mask.sum(dim=1).clamp(min=1)
594
+
595
+ if verbose:
596
+ # Calculate predictions
597
+ predicted_tokens = shift_logits.argmax(dim=-1) # [batch_size, seq_len-1]
598
+
599
+ # Compute per-example correct counts
600
+ correct_per_sample = ((predicted_tokens == shift_labels) & mask).sum(dim=1).float() # [batch_size]
601
+ total_tokens_per_sample = mask.sum(dim=1).float() # [batch_size]
602
+
603
+ # Total correct and total tokens across the batch
604
+ correct = correct_per_sample.sum()
605
+ total = total_tokens_per_sample.sum()
606
+
607
+ return {
608
+ "loss": loss,
609
+ "correct": correct,
610
+ "total": total,
611
+ "per_example_loss": loss_per_example,
612
+ "correct_per_sample": correct_per_sample,
613
+ "total_per_sample": total_tokens_per_sample,
614
+ }
615
+
616
+ return {"loss": loss, "per_example_loss": loss_per_example}
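A self-contained sketch of the masked per-example loss computed above: positions labelled -100 (BOS, audio, and padding) contribute nothing to a sample's average (toy shapes, random logits):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 5, 11)  # [batch, seq_len, vocab]
labels = torch.tensor([[-100, -100, 3, 4, 5],
                       [-100, 7, 8, -100, -100]])
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_per_token = CrossEntropyLoss(reduction="none")(
    shift_logits.view(-1, 11), shift_labels.view(-1)
).view(shift_labels.size())
mask = shift_labels != -100
per_example_loss = (loss_per_token * mask.float()).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
print(per_example_loss.shape)  # torch.Size([2])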
617
+
618
+ @torch.inference_mode()
619
+ def generate(self, samples, generate_cfg, prompts):
620
+ batch_size = len(prompts)
621
+
622
+ raw_wav = samples["raw_wav"]
623
+ audio_padding_mask = samples.get("padding_mask", None)
624
+
625
+ audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask=audio_padding_mask)
626
+ split_audio_embeds = list(torch.split(audio_embeds, samples["audio_chunk_sizes"], dim=0))
627
+ split_audio_atts = list(torch.split(audio_atts, samples["audio_chunk_sizes"], dim=0))
628
+ audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompts)
629
+ bos = (
630
+ torch.ones(
631
+ [batch_size, 1],
632
+ dtype=torch.int32,
633
+ device=audio_embeds.device,
634
+ )
635
+ * self.llama_tokenizer.bos_token_id
636
+ )
637
+ bos_embeds = self.llama_embed_tokens(bos)
638
+ atts_bos = audio_atts[:, :1]
639
+
640
+ embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
641
+
642
+ attns = torch.cat([atts_bos, audio_atts], dim=1)
643
+
644
+ stop_words_ids = [torch.tensor([2]).to(audio_embeds.device)]
645
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
646
+
647
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
648
+ outputs = self.llama_model.generate( # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
649
+ inputs_embeds=embeds.bfloat16(),
650
+ max_new_tokens=generate_cfg.max_new_tokens,
651
+ stopping_criteria=stopping_criteria,
652
+ num_beams=generate_cfg.num_beams,
653
+ do_sample=generate_cfg.do_sample,
654
+ min_length=generate_cfg.min_length,
655
+ temperature=generate_cfg.temperature,
656
+ # top_p=generate_cfg.get("top_p", 0.9),
657
+ repetition_penalty=generate_cfg.repetition_penalty,
658
+ length_penalty=generate_cfg.length_penalty,
659
+ attention_mask=attns.bfloat16(),
660
+ # prefix_allowed_tokens_fn=prefix_tokens_fn
661
+ # logits_processor=None
662
+ # constraints=[constraint] if constraint is not None else None
663
+ )
664
+ text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
665
+
666
+ return text
NatureLM/models/Qformer.py ADDED
@@ -0,0 +1,1091 @@
1
+ """
2
+ Adapted from salesforce@LAVIS. Below is the original copyright:
3
+ * Copyright (c) 2023, salesforce.com, inc.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ * By Junnan Li
8
+ * Based on huggingface code base
9
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10
+ """
11
+
12
+ import math
13
+ from typing import Tuple
14
+
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from torch import Tensor, device, nn
18
+ from torch.nn import CrossEntropyLoss
19
+ from transformers.activations import ACT2FN
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutputWithPastAndCrossAttentions,
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ CausalLMOutputWithCrossAttentions,
24
+ MaskedLMOutput,
25
+ )
26
+ from transformers.modeling_utils import (
27
+ PreTrainedModel,
28
+ apply_chunking_to_forward,
29
+ find_pruneable_heads_and_indices,
30
+ prune_linear_layer,
31
+ )
32
+ from transformers.models.bert.configuration_bert import BertConfig
33
+ from transformers.utils import logging
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ class BertEmbeddings(nn.Module):
39
+ """Construct the embeddings from word and position embeddings."""
40
+
41
+ def __init__(self, config):
42
+ super().__init__()
43
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
44
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
45
+
46
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
47
+ # any TensorFlow checkpoint file
48
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
49
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
50
+
51
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
52
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
53
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
54
+
55
+ self.config = config
56
+
57
+ def forward(
58
+ self,
59
+ input_ids=None,
60
+ position_ids=None,
61
+ query_embeds=None,
62
+ past_key_values_length=0,
63
+ ):
64
+ if input_ids is not None:
65
+ seq_length = input_ids.size()[1]
66
+ else:
67
+ seq_length = 0
68
+
69
+ if position_ids is None:
70
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
71
+
72
+ if input_ids is not None:
73
+ embeddings = self.word_embeddings(input_ids)
74
+ if self.position_embedding_type == "absolute":
75
+ position_embeddings = self.position_embeddings(position_ids)
76
+ embeddings = embeddings + position_embeddings
77
+
78
+ if query_embeds is not None:
79
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
80
+ else:
81
+ embeddings = query_embeds
82
+
83
+ embeddings = self.LayerNorm(embeddings)
84
+ embeddings = self.dropout(embeddings)
85
+ return embeddings
86
+
87
+
88
+ class BertSelfAttention(nn.Module):
89
+ def __init__(self, config, is_cross_attention):
90
+ super().__init__()
91
+ self.config = config
92
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
93
+ raise ValueError(
94
+ "The hidden size (%d) is not a multiple of the number of attention "
95
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
96
+ )
97
+
98
+ self.num_attention_heads = config.num_attention_heads
99
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
100
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
101
+
102
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
103
+ if is_cross_attention:
104
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
105
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
106
+ else:
107
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
108
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
109
+
110
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
111
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
112
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
113
+ self.max_position_embeddings = config.max_position_embeddings
114
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
115
+ self.save_attention = False
116
+
117
+ def save_attn_gradients(self, attn_gradients):
118
+ self.attn_gradients = attn_gradients
119
+
120
+ def get_attn_gradients(self):
121
+ return self.attn_gradients
122
+
123
+ def save_attention_map(self, attention_map):
124
+ self.attention_map = attention_map
125
+
126
+ def get_attention_map(self):
127
+ return self.attention_map
128
+
129
+ def transpose_for_scores(self, x):
130
+ new_x_shape = x.size()[:-1] + (
131
+ self.num_attention_heads,
132
+ self.attention_head_size,
133
+ )
134
+ x = x.view(*new_x_shape)
135
+ return x.permute(0, 2, 1, 3)
136
+
137
+ def forward(
138
+ self,
139
+ hidden_states,
140
+ attention_mask=None,
141
+ head_mask=None,
142
+ encoder_hidden_states=None,
143
+ encoder_attention_mask=None,
144
+ past_key_value=None,
145
+ output_attentions=False,
146
+ ):
147
+ # If this is instantiated as a cross-attention module, the keys
148
+ # and values come from an encoder; the attention mask needs to be
149
+ # such that the encoder's padding tokens are not attended to.
150
+ is_cross_attention = encoder_hidden_states is not None
151
+
152
+ if is_cross_attention:
153
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elif past_key_value is not None:
157
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
158
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
159
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
+ else:
162
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
164
+
165
+ mixed_query_layer = self.query(hidden_states)
166
+
167
+ query_layer = self.transpose_for_scores(mixed_query_layer)
168
+
169
+ past_key_value = (key_layer, value_layer)
170
+
171
+ # Take the dot product between "query" and "key" to get the raw attention scores.
172
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
173
+
174
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
175
+ seq_length = hidden_states.size()[1]
176
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
177
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
178
+ distance = position_ids_l - position_ids_r
179
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
180
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
181
+
182
+ if self.position_embedding_type == "relative_key":
183
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
184
+ attention_scores = attention_scores + relative_position_scores
185
+ elif self.position_embedding_type == "relative_key_query":
186
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
187
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
188
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
189
+
190
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
191
+ if attention_mask is not None:
192
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
193
+ attention_scores = attention_scores + attention_mask
194
+
195
+ # Normalize the attention scores to probabilities.
196
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
197
+
198
+ if is_cross_attention and self.save_attention:
199
+ self.save_attention_map(attention_probs)
200
+ attention_probs.register_hook(self.save_attn_gradients)
201
+
202
+ # This is actually dropping out entire tokens to attend to, which might
203
+ # seem a bit unusual, but is taken from the original Transformer paper.
204
+ attention_probs_dropped = self.dropout(attention_probs)
205
+
206
+ # Mask heads if we want to
207
+ if head_mask is not None:
208
+ attention_probs_dropped = attention_probs_dropped * head_mask
209
+
210
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
211
+
212
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
213
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
214
+ context_layer = context_layer.view(*new_context_layer_shape)
215
+
216
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
217
+
218
+ outputs = outputs + (past_key_value,)
219
+ return outputs
220
+
221
+
222
+ class BertSelfOutput(nn.Module):
223
+ def __init__(self, config):
224
+ super().__init__()
225
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
226
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
227
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
228
+
229
+ def forward(self, hidden_states, input_tensor):
230
+ hidden_states = self.dense(hidden_states)
231
+ hidden_states = self.dropout(hidden_states)
232
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
233
+ return hidden_states
234
+
235
+
236
+ class BertAttention(nn.Module):
237
+ def __init__(self, config, is_cross_attention=False):
238
+ super().__init__()
239
+ self.self = BertSelfAttention(config, is_cross_attention)
240
+ self.output = BertSelfOutput(config)
241
+ self.pruned_heads = set()
242
+
243
+ def prune_heads(self, heads):
244
+ if len(heads) == 0:
245
+ return
246
+ heads, index = find_pruneable_heads_and_indices(
247
+ heads,
248
+ self.self.num_attention_heads,
249
+ self.self.attention_head_size,
250
+ self.pruned_heads,
251
+ )
252
+
253
+ # Prune linear layers
254
+ self.self.query = prune_linear_layer(self.self.query, index)
255
+ self.self.key = prune_linear_layer(self.self.key, index)
256
+ self.self.value = prune_linear_layer(self.self.value, index)
257
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
258
+
259
+ # Update hyper params and store pruned heads
260
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
261
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
262
+ self.pruned_heads = self.pruned_heads.union(heads)
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states,
267
+ attention_mask=None,
268
+ head_mask=None,
269
+ encoder_hidden_states=None,
270
+ encoder_attention_mask=None,
271
+ past_key_value=None,
272
+ output_attentions=False,
273
+ ):
274
+ self_outputs = self.self(
275
+ hidden_states,
276
+ attention_mask,
277
+ head_mask,
278
+ encoder_hidden_states,
279
+ encoder_attention_mask,
280
+ past_key_value,
281
+ output_attentions,
282
+ )
283
+ attention_output = self.output(self_outputs[0], hidden_states)
284
+
285
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
286
+ return outputs
287
+
288
+
289
+ class BertIntermediate(nn.Module):
290
+ def __init__(self, config):
291
+ super().__init__()
292
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
293
+ if isinstance(config.hidden_act, str):
294
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
295
+ else:
296
+ self.intermediate_act_fn = config.hidden_act
297
+
298
+ def forward(self, hidden_states):
299
+ hidden_states = self.dense(hidden_states)
300
+ hidden_states = self.intermediate_act_fn(hidden_states)
301
+ return hidden_states
302
+
303
+
304
+ class BertOutput(nn.Module):
305
+ def __init__(self, config):
306
+ super().__init__()
307
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
308
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
309
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
310
+
311
+ def forward(self, hidden_states, input_tensor):
312
+ hidden_states = self.dense(hidden_states)
313
+ hidden_states = self.dropout(hidden_states)
314
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
315
+ return hidden_states
316
+
317
+
318
+ class BertLayer(nn.Module):
319
+ def __init__(self, config, layer_num):
320
+ super().__init__()
321
+ self.config = config
322
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
323
+ self.seq_len_dim = 1
324
+ self.attention = BertAttention(config)
325
+ self.layer_num = layer_num
326
+ if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
327
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
328
+ self.has_cross_attention = True
329
+ else:
330
+ self.has_cross_attention = False
331
+ self.intermediate = BertIntermediate(config)
332
+ self.output = BertOutput(config)
333
+
334
+ self.intermediate_query = BertIntermediate(config)
335
+ self.output_query = BertOutput(config)
336
+
337
+ def forward(
338
+ self,
339
+ hidden_states,
340
+ attention_mask=None,
341
+ head_mask=None,
342
+ encoder_hidden_states=None,
343
+ encoder_attention_mask=None,
344
+ past_key_value=None,
345
+ output_attentions=False,
346
+ query_length=0,
347
+ ):
348
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
349
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
350
+ self_attention_outputs = self.attention(
351
+ hidden_states,
352
+ attention_mask,
353
+ head_mask,
354
+ output_attentions=output_attentions,
355
+ past_key_value=self_attn_past_key_value,
356
+ )
357
+ attention_output = self_attention_outputs[0]
358
+ outputs = self_attention_outputs[1:-1]
359
+
360
+ present_key_value = self_attention_outputs[-1]
361
+
362
+ if query_length > 0:
363
+ query_attention_output = attention_output[:, :query_length, :]
364
+
365
+ if self.has_cross_attention:
366
+ assert (
367
+ encoder_hidden_states is not None
368
+ ), "encoder_hidden_states must be given for cross-attention layers"
369
+ cross_attention_outputs = self.crossattention(
370
+ query_attention_output,
371
+ attention_mask,
372
+ head_mask,
373
+ encoder_hidden_states,
374
+ encoder_attention_mask,
375
+ output_attentions=output_attentions,
376
+ )
377
+ query_attention_output = cross_attention_outputs[0]
378
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
379
+
380
+ layer_output = apply_chunking_to_forward(
381
+ self.feed_forward_chunk_query,
382
+ self.chunk_size_feed_forward,
383
+ self.seq_len_dim,
384
+ query_attention_output,
385
+ )
386
+ if attention_output.shape[1] > query_length:
387
+ layer_output_text = apply_chunking_to_forward(
388
+ self.feed_forward_chunk,
389
+ self.chunk_size_feed_forward,
390
+ self.seq_len_dim,
391
+ attention_output[:, query_length:, :],
392
+ )
393
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
394
+ else:
395
+ layer_output = apply_chunking_to_forward(
396
+ self.feed_forward_chunk,
397
+ self.chunk_size_feed_forward,
398
+ self.seq_len_dim,
399
+ attention_output,
400
+ )
401
+ outputs = (layer_output,) + outputs
402
+
403
+ outputs = outputs + (present_key_value,)
404
+
405
+ return outputs
406
+
407
+ def feed_forward_chunk(self, attention_output):
408
+ intermediate_output = self.intermediate(attention_output)
409
+ layer_output = self.output(intermediate_output, attention_output)
410
+ return layer_output
411
+
412
+ def feed_forward_chunk_query(self, attention_output):
413
+ intermediate_output = self.intermediate_query(attention_output)
414
+ layer_output = self.output_query(intermediate_output, attention_output)
415
+ return layer_output
416
+
417
+
418
+ class BertEncoder(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.config = config
422
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
423
+
424
+ def forward(
425
+ self,
426
+ hidden_states,
427
+ attention_mask=None,
428
+ head_mask=None,
429
+ encoder_hidden_states=None,
430
+ encoder_attention_mask=None,
431
+ past_key_values=None,
432
+ use_cache=None,
433
+ output_attentions=False,
434
+ output_hidden_states=False,
435
+ return_dict=True,
436
+ query_length=0,
437
+ ):
438
+ all_hidden_states = () if output_hidden_states else None
439
+ all_self_attentions = () if output_attentions else None
440
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
441
+
442
+ next_decoder_cache = () if use_cache else None
443
+
444
+ for i in range(self.config.num_hidden_layers):
445
+ layer_module = self.layer[i]
446
+ if output_hidden_states:
447
+ all_hidden_states = all_hidden_states + (hidden_states,)
448
+
449
+ layer_head_mask = head_mask[i] if head_mask is not None else None
450
+ past_key_value = past_key_values[i] if past_key_values is not None else None
451
+
452
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
453
+ if use_cache:
454
+ logger.warn(
455
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
456
+ )
457
+ use_cache = False
458
+
459
+ def create_custom_forward(module):
460
+ def custom_forward(*inputs):
461
+ return module(*inputs, past_key_value, output_attentions, query_length)
462
+
463
+ return custom_forward
464
+
465
+ layer_outputs = torch.utils.checkpoint.checkpoint(
466
+ create_custom_forward(layer_module),
467
+ hidden_states,
468
+ attention_mask,
469
+ layer_head_mask,
470
+ encoder_hidden_states,
471
+ encoder_attention_mask,
472
+ )
473
+ else:
474
+ layer_outputs = layer_module(
475
+ hidden_states,
476
+ attention_mask,
477
+ layer_head_mask,
478
+ encoder_hidden_states,
479
+ encoder_attention_mask,
480
+ past_key_value,
481
+ output_attentions,
482
+ query_length,
483
+ )
484
+
485
+ hidden_states = layer_outputs[0]
486
+ if use_cache:
487
+ next_decoder_cache += (layer_outputs[-1],)
488
+ if output_attentions:
489
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
490
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
491
+
492
+ if output_hidden_states:
493
+ all_hidden_states = all_hidden_states + (hidden_states,)
494
+
495
+ if not return_dict:
496
+ return tuple(
497
+ v
498
+ for v in [
499
+ hidden_states,
500
+ next_decoder_cache,
501
+ all_hidden_states,
502
+ all_self_attentions,
503
+ all_cross_attentions,
504
+ ]
505
+ if v is not None
506
+ )
507
+ return BaseModelOutputWithPastAndCrossAttentions(
508
+ last_hidden_state=hidden_states,
509
+ past_key_values=next_decoder_cache,
510
+ hidden_states=all_hidden_states,
511
+ attentions=all_self_attentions,
512
+ cross_attentions=all_cross_attentions,
513
+ )
514
+
515
+
516
+ class BertPooler(nn.Module):
517
+ def __init__(self, config):
518
+ super().__init__()
519
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
520
+ self.activation = nn.Tanh()
521
+
522
+ def forward(self, hidden_states):
523
+ # We "pool" the model by simply taking the hidden state corresponding
524
+ # to the first token.
525
+ first_token_tensor = hidden_states[:, 0]
526
+ pooled_output = self.dense(first_token_tensor)
527
+ pooled_output = self.activation(pooled_output)
528
+ return pooled_output
529
+
530
+
531
+ class BertPredictionHeadTransform(nn.Module):
532
+ def __init__(self, config):
533
+ super().__init__()
534
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
535
+ if isinstance(config.hidden_act, str):
536
+ self.transform_act_fn = ACT2FN[config.hidden_act]
537
+ else:
538
+ self.transform_act_fn = config.hidden_act
539
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
540
+
541
+ def forward(self, hidden_states):
542
+ hidden_states = self.dense(hidden_states)
543
+ hidden_states = self.transform_act_fn(hidden_states)
544
+ hidden_states = self.LayerNorm(hidden_states)
545
+ return hidden_states
546
+
547
+
548
+ class BertLMPredictionHead(nn.Module):
549
+ def __init__(self, config):
550
+ super().__init__()
551
+ self.transform = BertPredictionHeadTransform(config)
552
+
553
+ # The output weights are the same as the input embeddings, but there is
554
+ # an output-only bias for each token.
555
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
556
+
557
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
558
+
559
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
560
+ self.decoder.bias = self.bias
561
+
562
+ def forward(self, hidden_states):
563
+ hidden_states = self.transform(hidden_states)
564
+ hidden_states = self.decoder(hidden_states)
565
+ return hidden_states
566
+
567
+
568
+ class BertOnlyMLMHead(nn.Module):
569
+ def __init__(self, config):
570
+ super().__init__()
571
+ self.predictions = BertLMPredictionHead(config)
572
+
573
+ def forward(self, sequence_output):
574
+ prediction_scores = self.predictions(sequence_output)
575
+ return prediction_scores
576
+
577
+
578
+ class BertPreTrainedModel(PreTrainedModel):
579
+ """
580
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
581
+ models.
582
+ """
583
+
584
+ config_class = BertConfig
585
+ base_model_prefix = "bert"
586
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
587
+
588
+ def _init_weights(self, module):
589
+ """Initialize the weights"""
590
+ if isinstance(module, (nn.Linear, nn.Embedding)):
591
+ # Slightly different from the TF version which uses truncated_normal for initialization
592
+ # cf https://github.com/pytorch/pytorch/pull/5617
593
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
594
+ elif isinstance(module, nn.LayerNorm):
595
+ module.bias.data.zero_()
596
+ module.weight.data.fill_(1.0)
597
+ if isinstance(module, nn.Linear) and module.bias is not None:
598
+ module.bias.data.zero_()
599
+
600
+
601
+ class BertModel(BertPreTrainedModel):
602
+ """
603
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
604
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
605
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
606
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
607
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
608
+ input to the forward pass.
609
+ """
610
+
611
+ def __init__(self, config, add_pooling_layer=False):
612
+ super().__init__(config)
613
+ self.config = config
614
+
615
+ self.embeddings = BertEmbeddings(config)
616
+
617
+ self.encoder = BertEncoder(config)
618
+
619
+ self.pooler = BertPooler(config) if add_pooling_layer else None
620
+
621
+ self.init_weights()
622
+
623
+ def get_input_embeddings(self):
624
+ return self.embeddings.word_embeddings
625
+
626
+ def set_input_embeddings(self, value):
627
+ self.embeddings.word_embeddings = value
628
+
629
+ def _prune_heads(self, heads_to_prune):
630
+ """
631
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
632
+ class PreTrainedModel
633
+ """
634
+ for layer, heads in heads_to_prune.items():
635
+ self.encoder.layer[layer].attention.prune_heads(heads)
636
+
637
+ def get_extended_attention_mask(
638
+ self,
639
+ attention_mask: Tensor,
640
+ input_shape: Tuple[int],
641
+ device: device,
642
+ is_decoder: bool,
643
+ has_query: bool = False,
644
+ ) -> Tensor:
645
+ """
646
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
647
+
648
+ Arguments:
649
+ attention_mask (:obj:`torch.Tensor`):
650
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
651
+ input_shape (:obj:`Tuple[int]`):
652
+ The shape of the input to the model.
653
+ device: (:obj:`torch.device`):
654
+ The device of the input to the model.
655
+
656
+ Returns:
657
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
658
+ """
659
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
660
+ # ourselves in which case we just need to make it broadcastable to all heads.
661
+ if attention_mask.dim() == 3:
662
+ extended_attention_mask = attention_mask[:, None, :, :]
663
+ elif attention_mask.dim() == 2:
664
+ # Provided a padding mask of dimensions [batch_size, seq_length]
665
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
666
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
667
+ if is_decoder:
668
+ batch_size, seq_length = input_shape
669
+
670
+ seq_ids = torch.arange(seq_length, device=device)
671
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
672
+
673
+ # add a prefix ones mask to the causal mask
674
+ # causal and attention masks must have same type with pytorch version < 1.3
675
+ causal_mask = causal_mask.to(attention_mask.dtype)
676
+
677
+ if causal_mask.shape[1] < attention_mask.shape[1]:
678
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
679
+ if has_query: # UniLM style attention mask
680
+ causal_mask = torch.cat(
681
+ [
682
+ torch.zeros(
683
+ (batch_size, prefix_seq_len, seq_length),
684
+ device=device,
685
+ dtype=causal_mask.dtype,
686
+ ),
687
+ causal_mask,
688
+ ],
689
+ axis=1,
690
+ )
691
+ causal_mask = torch.cat(
692
+ [
693
+ torch.ones(
694
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
695
+ device=device,
696
+ dtype=causal_mask.dtype,
697
+ ),
698
+ causal_mask,
699
+ ],
700
+ axis=-1,
701
+ )
702
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
703
+ else:
704
+ extended_attention_mask = attention_mask[:, None, None, :]
705
+ else:
706
+ raise ValueError(
707
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
708
+ input_shape, attention_mask.shape
709
+ )
710
+ )
711
+
712
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
713
+ # masked positions, this operation will create a tensor which is 0.0 for
714
+ # positions we want to attend and -10000.0 for masked positions.
715
+ # Since we are adding it to the raw scores before the softmax, this is
716
+ # effectively the same as removing these entirely.
717
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
718
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
719
+ return extended_attention_mask
720
+
721
+ def forward(
722
+ self,
723
+ input_ids=None,
724
+ attention_mask=None,
725
+ position_ids=None,
726
+ head_mask=None,
727
+ query_embeds=None,
728
+ encoder_hidden_states=None,
729
+ encoder_attention_mask=None,
730
+ past_key_values=None,
731
+ use_cache=None,
732
+ output_attentions=None,
733
+ output_hidden_states=None,
734
+ return_dict=None,
735
+ is_decoder=False,
736
+ ):
737
+ r"""
738
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
739
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
740
+ the model is configured as a decoder.
741
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
742
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
743
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
744
+ - 1 for tokens that are **not masked**,
745
+ - 0 for tokens that are **masked**.
746
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
747
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
748
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
749
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
750
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
751
+ use_cache (:obj:`bool`, `optional`):
752
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
753
+ decoding (see :obj:`past_key_values`).
754
+ """
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+
761
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
762
+
763
+ if input_ids is None:
764
+ assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"
765
+
766
+ # past_key_values_length
767
+ past_key_values_length = (
768
+ past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
769
+ )
770
+
771
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
772
+
773
+ embedding_output = self.embeddings(
774
+ input_ids=input_ids,
775
+ position_ids=position_ids,
776
+ query_embeds=query_embeds,
777
+ past_key_values_length=past_key_values_length,
778
+ )
779
+
780
+ input_shape = embedding_output.size()[:-1]
781
+ batch_size, seq_length = input_shape
782
+ device = embedding_output.device
783
+
784
+ if attention_mask is None:
785
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
786
+
787
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
788
+ # ourselves in which case we just need to make it broadcastable to all heads.
789
+ if is_decoder:
790
+ extended_attention_mask = self.get_extended_attention_mask(
791
+ attention_mask,
792
+ input_ids.shape,
793
+ device,
794
+ is_decoder,
795
+ has_query=(query_embeds is not None),
796
+ )
797
+ else:
798
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)
799
+
800
+ # If a 2D or 3D attention mask is provided for the cross-attention
801
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
802
+ if encoder_hidden_states is not None:
803
+ if isinstance(encoder_hidden_states, list):
804
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
805
+ else:
806
+ (
807
+ encoder_batch_size,
808
+ encoder_sequence_length,
809
+ _,
810
+ ) = encoder_hidden_states.size()
811
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
812
+
813
+ if isinstance(encoder_attention_mask, list):
814
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
815
+ elif encoder_attention_mask is None:
816
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
817
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
818
+ else:
819
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
820
+ else:
821
+ encoder_extended_attention_mask = None
822
+
823
+ # Prepare head mask if needed
824
+ # 1.0 in head_mask indicate we keep the head
825
+ # attention_probs has shape bsz x n_heads x N x N
826
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
827
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
828
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
829
+
830
+ encoder_outputs = self.encoder(
831
+ embedding_output,
832
+ attention_mask=extended_attention_mask,
833
+ head_mask=head_mask,
834
+ encoder_hidden_states=encoder_hidden_states,
835
+ encoder_attention_mask=encoder_extended_attention_mask,
836
+ past_key_values=past_key_values,
837
+ use_cache=use_cache,
838
+ output_attentions=output_attentions,
839
+ output_hidden_states=output_hidden_states,
840
+ return_dict=return_dict,
841
+ query_length=query_length,
842
+ )
843
+ sequence_output = encoder_outputs[0]
844
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
845
+
846
+ if not return_dict:
847
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
848
+
849
+ return BaseModelOutputWithPoolingAndCrossAttentions(
850
+ last_hidden_state=sequence_output,
851
+ pooler_output=pooled_output,
852
+ past_key_values=encoder_outputs.past_key_values,
853
+ hidden_states=encoder_outputs.hidden_states,
854
+ attentions=encoder_outputs.attentions,
855
+ cross_attentions=encoder_outputs.cross_attentions,
856
+ )
857
+
858
+
859
+ class BertLMHeadModel(BertPreTrainedModel):
860
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
861
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
862
+
863
+ def __init__(self, config):
864
+ super().__init__(config)
865
+
866
+ self.bert = BertModel(config, add_pooling_layer=False)
867
+ self.cls = BertOnlyMLMHead(config)
868
+
869
+ self.init_weights()
870
+
871
+ def get_output_embeddings(self):
872
+ return self.cls.predictions.decoder
873
+
874
+ def set_output_embeddings(self, new_embeddings):
875
+ self.cls.predictions.decoder = new_embeddings
876
+
877
+ def forward(
878
+ self,
879
+ input_ids=None,
880
+ attention_mask=None,
881
+ position_ids=None,
882
+ head_mask=None,
883
+ query_embeds=None,
884
+ encoder_hidden_states=None,
885
+ encoder_attention_mask=None,
886
+ labels=None,
887
+ past_key_values=None,
888
+ use_cache=True,
889
+ output_attentions=None,
890
+ output_hidden_states=None,
891
+ return_dict=None,
892
+ return_logits=False,
893
+ is_decoder=True,
894
+ reduction="mean",
895
+ ):
896
+ r"""
897
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
898
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
899
+ the model is configured as a decoder.
900
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
901
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
902
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
903
+ - 1 for tokens that are **not masked**,
904
+ - 0 for tokens that are **masked**.
905
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
906
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
907
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
908
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
909
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
910
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
911
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
912
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
913
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
914
+ use_cache (:obj:`bool`, `optional`):
915
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
916
+ decoding (see :obj:`past_key_values`).
917
+ Returns:
918
+ Example::
919
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
920
+ >>> import torch
921
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
922
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
923
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
924
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
925
+ >>> outputs = model(**inputs)
926
+ >>> prediction_logits = outputs.logits
927
+ """
928
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
929
+ if labels is not None:
930
+ use_cache = False
931
+ if past_key_values is not None:
932
+ query_embeds = None
933
+
934
+ outputs = self.bert(
935
+ input_ids,
936
+ attention_mask=attention_mask,
937
+ position_ids=position_ids,
938
+ head_mask=head_mask,
939
+ query_embeds=query_embeds,
940
+ encoder_hidden_states=encoder_hidden_states,
941
+ encoder_attention_mask=encoder_attention_mask,
942
+ past_key_values=past_key_values,
943
+ use_cache=use_cache,
944
+ output_attentions=output_attentions,
945
+ output_hidden_states=output_hidden_states,
946
+ return_dict=return_dict,
947
+ is_decoder=is_decoder,
948
+ )
949
+
950
+ sequence_output = outputs[0]
951
+ if query_embeds is not None:
952
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
953
+
954
+ prediction_scores = self.cls(sequence_output)
955
+
956
+ if return_logits:
957
+ return prediction_scores[:, :-1, :].contiguous()
958
+
959
+ lm_loss = None
960
+ if labels is not None:
961
+ # we are doing next-token prediction; shift prediction scores and input ids by one
962
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
963
+ labels = labels[:, 1:].contiguous()
964
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
965
+ lm_loss = loss_fct(
966
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
967
+ labels.view(-1),
968
+ )
969
+ if reduction == "none":
970
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
971
+
972
+ if not return_dict:
973
+ output = (prediction_scores,) + outputs[2:]
974
+ return ((lm_loss,) + output) if lm_loss is not None else output
975
+
976
+ return CausalLMOutputWithCrossAttentions(
977
+ loss=lm_loss,
978
+ logits=prediction_scores,
979
+ past_key_values=outputs.past_key_values,
980
+ hidden_states=outputs.hidden_states,
981
+ attentions=outputs.attentions,
982
+ cross_attentions=outputs.cross_attentions,
983
+ )
984
+
985
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
986
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
987
+ if attention_mask is None:
988
+ attention_mask = input_ids.new_ones(input_ids.shape)
989
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
990
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
991
+
992
+ # cut decoder_input_ids if past is used
993
+ if past is not None:
994
+ input_ids = input_ids[:, -1:]
995
+
996
+ return {
997
+ "input_ids": input_ids,
998
+ "query_embeds": query_embeds,
999
+ "attention_mask": attention_mask,
1000
+ "past_key_values": past,
1001
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1002
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1003
+ "is_decoder": True,
1004
+ }
1005
+
1006
+ def _reorder_cache(self, past, beam_idx):
1007
+ reordered_past = ()
1008
+ for layer_past in past:
1009
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1010
+ return reordered_past
1011
+
1012
+
1013
+ class BertForMaskedLM(BertPreTrainedModel):
1014
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1015
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1016
+
1017
+ def __init__(self, config):
1018
+ super().__init__(config)
1019
+
1020
+ self.bert = BertModel(config, add_pooling_layer=False)
1021
+ self.cls = BertOnlyMLMHead(config)
1022
+
1023
+ self.init_weights()
1024
+
1025
+ def get_output_embeddings(self):
1026
+ return self.cls.predictions.decoder
1027
+
1028
+ def set_output_embeddings(self, new_embeddings):
1029
+ self.cls.predictions.decoder = new_embeddings
1030
+
1031
+ def forward(
1032
+ self,
1033
+ input_ids=None,
1034
+ attention_mask=None,
1035
+ position_ids=None,
1036
+ head_mask=None,
1037
+ query_embeds=None,
1038
+ encoder_hidden_states=None,
1039
+ encoder_attention_mask=None,
1040
+ labels=None,
1041
+ output_attentions=None,
1042
+ output_hidden_states=None,
1043
+ return_dict=None,
1044
+ return_logits=False,
1045
+ is_decoder=False,
1046
+ ):
1047
+ r"""
1048
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1049
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1050
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1051
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1052
+ """
1053
+
1054
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1055
+
1056
+ outputs = self.bert(
1057
+ input_ids,
1058
+ attention_mask=attention_mask,
1059
+ position_ids=position_ids,
1060
+ head_mask=head_mask,
1061
+ query_embeds=query_embeds,
1062
+ encoder_hidden_states=encoder_hidden_states,
1063
+ encoder_attention_mask=encoder_attention_mask,
1064
+ output_attentions=output_attentions,
1065
+ output_hidden_states=output_hidden_states,
1066
+ return_dict=return_dict,
1067
+ is_decoder=is_decoder,
1068
+ )
1069
+
1070
+ if query_embeds is not None:
1071
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1072
+ prediction_scores = self.cls(sequence_output)
1073
+
1074
+ if return_logits:
1075
+ return prediction_scores
1076
+
1077
+ masked_lm_loss = None
1078
+ if labels is not None:
1079
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1080
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1081
+
1082
+ if not return_dict:
1083
+ output = (prediction_scores,) + outputs[2:]
1084
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1085
+
1086
+ return MaskedLMOutput(
1087
+ loss=masked_lm_loss,
1088
+ logits=prediction_scores,
1089
+ hidden_states=outputs.hidden_states,
1090
+ attentions=outputs.attentions,
1091
+ )
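For orientation, here is a minimal usage sketch for the Q-Former classes above: learned query tokens are fed through BertLMHeadModel.bert and cross-attend to external encoder features. Assumptions not shown in this diff: the file is importable as NatureLM.models.Qformer, BertConfig is the transformers config (unknown kwargs become attributes), and the custom attributes encoder_width, cross_attention_freq and query_length are otherwise set when loading a checkpoint; the shapes and random weights below are illustrative only.

import torch
from transformers import BertConfig
from NatureLM.models.Qformer import BertLMHeadModel  # assumed module path for this file

cfg = BertConfig(
    hidden_size=768,
    num_hidden_layers=2,        # kept small for the sketch
    num_attention_heads=12,
    add_cross_attention=True,   # adds a cross-attention block every cross_attention_freq layers
    cross_attention_freq=2,
    query_length=32,
)
cfg.encoder_width = 512         # feature width of the external (e.g. audio) encoder

qformer = BertLMHeadModel(cfg)
query_embeds = torch.randn(1, 32, cfg.hidden_size)       # learned query tokens (random here)
encoder_feats = torch.randn(1, 100, cfg.encoder_width)   # frame-level encoder output
encoder_mask = torch.ones(1, 100, dtype=torch.long)

out = qformer.bert(
    query_embeds=query_embeds,
    encoder_hidden_states=encoder_feats,
    encoder_attention_mask=encoder_mask,
    return_dict=True,
)
print(out.last_hidden_state.shape)  # (1, 32, 768): one output vector per query token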
NatureLM/models/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .NatureLM import NatureLM
16
+
17
+
18
+ def load_model(config):
19
+ return NatureLM.from_config(config)
NatureLM/models/__pycache__/NatureLM.cpython-310.pyc ADDED
Binary file (15.3 kB).
 
NatureLM/models/__pycache__/Qformer.cpython-310.pyc ADDED
Binary file (30 kB).
 
NatureLM/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (329 Bytes).
 
NatureLM/models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (926 Bytes).
 
NatureLM/models/aves.py ADDED
@@ -0,0 +1,59 @@
1
+ import json
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchaudio.models import wav2vec2_model
7
+
8
+
9
+ class AvesEmbedding(nn.Module):
10
+ def __init__(self, sr, large=False):
11
+ super().__init__()
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
15
+ if large:
16
+ config = self.load_config("configs/birdaves_bioxlarge.config")
17
+ else:
18
+ config = self.load_config("configs/birdaves_bioxbase.config")
19
+ self.model = wav2vec2_model(**config, aux_num_out=None)
20
+ state_dict = torch.hub.load_state_dict_from_url(
21
+ "https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.pt",
22
+ map_location=device,
23
+ )
24
+ self.model.load_state_dict(state_dict)
25
+ self.model.feature_extractor.requires_grad_(True)
26
+
27
+ # bundle = torchaudio.pipelines.WAV2VEC2_BASE
28
+ # self.model = bundle.get_model()
29
+
30
+ self.sr = sr
31
+
32
+ def load_config(self, config_path):
33
+ with open(config_path, "r") as ff:
34
+ obj = json.load(ff)
35
+
36
+ return obj
37
+
38
+ def forward(self, sig, padding_mask):
39
+ # extract_feature in the torchaudio version will output all 12 layers' output, -1 to select the final one
40
+ # print("sig", sig)
41
+
42
+ out = self.model.extract_features(sig.float())[0][-1]
43
+ atts = ~padding_mask
44
+ atts = atts.unsqueeze(1).float()
45
+ atts = F.max_pool1d(atts, kernel_size=320, stride=320)
46
+ atts = atts > 0
47
+ padding_mask = ~atts
48
+
49
+ return out, padding_mask
50
+
51
+ def freeze(self):
52
+ for param in self.model.encoder.parameters():
53
+ param.requires_grad = False
54
+ self.model.feature_extractor.requires_grad_(False)
55
+
56
+ def unfreeze(self):
57
+ for param in self.model.encoder.parameters():
58
+ param.requires_grad = True
59
+ self.model.feature_extractor.requires_grad_(True)
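To make the mask handling in AvesEmbedding.forward above concrete, a small self-contained sketch of the padding-mask downsampling (assumption: 16 kHz input, so the 320-sample hop corresponds to one ~20 ms wav2vec2 frame); the values are illustrative.

import torch
import torch.nn.functional as F

padding_mask = torch.zeros(2, 16000, dtype=torch.bool)  # two 1 s clips at 16 kHz
padding_mask[1, 8000:] = True                           # second clip is padded after 0.5 s

atts = (~padding_mask).unsqueeze(1).float()             # 1.0 where there is real audio
atts = F.max_pool1d(atts, kernel_size=320, stride=320) > 0
frame_mask = ~atts.squeeze(1)                           # True where an output frame is pure padding
print(frame_mask.shape, frame_mask[1].sum().item())     # torch.Size([2, 50]) 25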
NatureLM/models/beats/BEATs.py ADDED
@@ -0,0 +1,181 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import logging
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torchaudio.compliance.kaldi as ta_kaldi
17
+ from torch.nn import LayerNorm
18
+
19
+ from .backbone import TransformerEncoder
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class BEATsConfig:
25
+ def __init__(self, cfg=None):
26
+ self.input_patch_size: int = -1 # patch size of the patch embedding
27
+ self.embed_dim: int = 512 # patch embedding dimension
28
+ self.conv_bias: bool = False # include bias in conv encoder
29
+
30
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
31
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
32
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
33
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
34
+ self.activation_fn: str = "gelu" # activation function to use
35
+
36
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
37
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
38
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
39
+
40
+ # dropouts
41
+ self.dropout: float = 0.1 # dropout probability for the transformer
42
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
43
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
44
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
45
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
46
+
47
+ # positional embeddings
48
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
49
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
50
+
51
+ # relative position embedding
52
+ self.relative_position_embedding: bool = False # apply relative position embedding
53
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
54
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
55
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
56
+
57
+ # label predictor
58
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
59
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
60
+ self.predictor_class: int = 527 # target class number for the predictor
61
+
62
+ if cfg is not None:
63
+ self.update(cfg)
64
+
65
+ def update(self, cfg: dict):
66
+ self.__dict__.update(cfg)
67
+
68
+ def to_dict(self):
69
+ return self.__dict__
70
+
71
+
72
+ class BEATs(nn.Module):
73
+ def __init__(
74
+ self,
75
+ cfg: BEATsConfig,
76
+ ) -> None:
77
+ super().__init__()
78
+ logger.info(f"BEATs Config: {cfg.__dict__}")
79
+
80
+ self.cfg = cfg
81
+
82
+ self.embed = cfg.embed_dim
83
+ self.post_extract_proj = (
84
+ nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
85
+ )
86
+
87
+ self.input_patch_size = cfg.input_patch_size
88
+ self.patch_embedding = nn.Conv2d(
89
+ 1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias
90
+ )
91
+
92
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
93
+
94
+ assert not cfg.deep_norm or not cfg.layer_norm_first
95
+ self.encoder = TransformerEncoder(cfg)
96
+ self.layer_norm = LayerNorm(self.embed)
97
+
98
+ if cfg.finetuned_model:
99
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
100
+ self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
101
+ else:
102
+ self.predictor = None
103
+
104
+ def forward_padding_mask(
105
+ self,
106
+ features: torch.Tensor,
107
+ padding_mask: torch.Tensor,
108
+ ) -> torch.Tensor:
109
+ extra = padding_mask.size(1) % features.size(1)
110
+ if extra > 0:
111
+ padding_mask = padding_mask[:, :-extra]
112
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
113
+ padding_mask = padding_mask.all(-1)
114
+ return padding_mask
115
+
116
+ def preprocess(
117
+ self,
118
+ source: torch.Tensor,
119
+ fbank_mean: float = 15.41663,
120
+ fbank_std: float = 6.55582,
121
+ ) -> torch.Tensor:
122
+ fbanks = []
123
+ for waveform in source:
124
+ waveform = waveform.unsqueeze(0) * 2**15
125
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
126
+ fbanks.append(fbank)
127
+ fbank = torch.stack(fbanks, dim=0)
128
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
129
+ return fbank
130
+
131
+ def extract_features(
132
+ self,
133
+ source: torch.Tensor,
134
+ padding_mask: Optional[torch.Tensor] = None,
135
+ fbank_mean: float = 15.41663,
136
+ fbank_std: float = 6.55582,
137
+ feature_only=False,
138
+ ):
139
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)
140
+
141
+ if padding_mask is not None:
142
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
143
+
144
+ fbank = fbank.unsqueeze(1)
145
+ features = self.patch_embedding(fbank)
146
+ features = features.reshape(features.shape[0], features.shape[1], -1)
147
+ features = features.transpose(1, 2)
148
+ features = self.layer_norm(features)
149
+
150
+ if padding_mask is not None:
151
+ padding_mask = self.forward_padding_mask(features, padding_mask)
152
+
153
+ if self.post_extract_proj is not None:
154
+ features = self.post_extract_proj(features)
155
+
156
+ x = self.dropout_input(features)
157
+
158
+ x, layer_results = self.encoder(
159
+ x,
160
+ padding_mask=padding_mask,
161
+ )
162
+
163
+ if not feature_only and self.predictor is not None:
164
+ x = self.predictor_dropout(x)
165
+ logits = self.predictor(x)
166
+
167
+ if padding_mask is not None and padding_mask.any():
168
+ logits[padding_mask] = 0
169
+ logits = logits.sum(dim=1)
170
+ logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
171
+ else:
172
+ logits = logits.mean(dim=1)
173
+
174
+ lprobs = torch.sigmoid(logits)
175
+
176
+ return lprobs, padding_mask
177
+ else:
178
+ return x, padding_mask
179
+
180
+ def forward(self, source: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
181
+ return self.extract_features(source, padding_mask, feature_only=True)
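As a shape check for the feature path above, a minimal sketch with a randomly initialized BEATs (assumption: in practice the config and weights come from a released BEATs checkpoint; a tiny 2-layer encoder is used here only to keep the example light).

import torch
from NatureLM.models.beats.BEATs import BEATs, BEATsConfig

cfg = BEATsConfig({"input_patch_size": 16, "encoder_layers": 2})
model = BEATs(cfg).eval()

wav = torch.randn(1, 16000)  # 1 s of 16 kHz audio
with torch.no_grad():
    feats, _ = model.extract_features(wav, feature_only=True)
print(feats.shape)  # (1, 48, 768): (98 fbank frames // 16) * (128 mel bins // 16) patches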
NatureLM/models/beats/Tokenizers.py ADDED
@@ -0,0 +1,173 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import logging
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torchaudio.compliance.kaldi as ta_kaldi
17
+ from torch.nn import LayerNorm
18
+
19
+ from .backbone import (
20
+ TransformerEncoder,
21
+ )
22
+ from .quantizer import (
23
+ NormEMAVectorQuantizer,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class TokenizersConfig:
30
+ def __init__(self, cfg=None):
31
+ self.input_patch_size: int = -1 # patch size of the patch embedding
32
+ self.embed_dim: int = 512 # patch embedding dimension
33
+ self.conv_bias: bool = False # include bias in conv encoder
34
+
35
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
36
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
37
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
38
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
39
+ self.activation_fn: str = "gelu" # activation function to use
40
+
41
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
42
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
43
+
44
+ # dropouts
45
+ self.dropout: float = 0.1 # dropout probability for the transformer
46
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
47
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
48
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
49
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
50
+
51
+ # positional embeddings
52
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
53
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
54
+
55
+ # relative position embedding
56
+ self.relative_position_embedding: bool = False # apply relative position embedding
57
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
58
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
59
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
60
+
61
+ # quantizer
62
+ self.quant_n: int = 1024 # codebook number in quantizer
63
+ self.quant_dim: int = 256 # codebook dimension in quantizer
64
+
65
+ if cfg is not None:
66
+ self.update(cfg)
67
+
68
+ def update(self, cfg: dict):
69
+ self.__dict__.update(cfg)
70
+
71
+
72
+ class Tokenizers(nn.Module):
73
+ def __init__(
74
+ self,
75
+ cfg: TokenizersConfig,
76
+ ) -> None:
77
+ super().__init__()
78
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
79
+
80
+ self.cfg = cfg
81
+
82
+ self.embed = cfg.embed_dim
83
+ self.post_extract_proj = (
84
+ nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
85
+ )
86
+
87
+ self.input_patch_size = cfg.input_patch_size
88
+ self.patch_embedding = nn.Conv2d(
89
+ 1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias
90
+ )
91
+
92
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
93
+
94
+ assert not cfg.deep_norm or not cfg.layer_norm_first
95
+ self.encoder = TransformerEncoder(cfg)
96
+ self.layer_norm = LayerNorm(self.embed)
97
+
98
+ self.quantize = NormEMAVectorQuantizer(
99
+ n_embed=cfg.quant_n,
100
+ embedding_dim=cfg.quant_dim,
101
+ beta=1.0,
102
+ kmeans_init=True,
103
+ decay=0.99,
104
+ )
105
+ self.quant_n = cfg.quant_n
106
+ self.quantize_layer = nn.Sequential(
107
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
108
+ nn.Tanh(),
109
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim), # for quantize
110
+ )
111
+
112
+ def forward_padding_mask(
113
+ self,
114
+ features: torch.Tensor,
115
+ padding_mask: torch.Tensor,
116
+ ) -> torch.Tensor:
117
+ extra = padding_mask.size(1) % features.size(1)
118
+ if extra > 0:
119
+ padding_mask = padding_mask[:, :-extra]
120
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
121
+ padding_mask = padding_mask.all(-1)
122
+ return padding_mask
123
+
124
+ def preprocess(
125
+ self,
126
+ source: torch.Tensor,
127
+ fbank_mean: float = 15.41663,
128
+ fbank_std: float = 6.55582,
129
+ ) -> torch.Tensor:
130
+ fbanks = []
131
+ for waveform in source:
132
+ waveform = waveform.unsqueeze(0) * 2**15
133
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
134
+ fbanks.append(fbank)
135
+ fbank = torch.stack(fbanks, dim=0)
136
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
137
+ return fbank
138
+
139
+ def extract_labels(
140
+ self,
141
+ source: torch.Tensor,
142
+ padding_mask: Optional[torch.Tensor] = None,
143
+ fbank_mean: float = 15.41663,
144
+ fbank_std: float = 6.55582,
145
+ ):
146
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
147
+
148
+ if padding_mask is not None:
149
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
150
+
151
+ fbank = fbank.unsqueeze(1)
152
+ features = self.patch_embedding(fbank)
153
+ features = features.reshape(features.shape[0], features.shape[1], -1)
154
+ features = features.transpose(1, 2)
155
+ features = self.layer_norm(features)
156
+
157
+ if padding_mask is not None:
158
+ padding_mask = self.forward_padding_mask(features, padding_mask)
159
+
160
+ if self.post_extract_proj is not None:
161
+ features = self.post_extract_proj(features)
162
+
163
+ x = self.dropout_input(features)
164
+
165
+ x, layer_results = self.encoder(
166
+ x,
167
+ padding_mask=padding_mask,
168
+ )
169
+
170
+ quantize_input = self.quantize_layer(x)
171
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
172
+
173
+ return embed_ind
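For orientation, a minimal sketch of the quantization bottleneck used by extract_labels above, shapes only; the actual codebook lookup happens inside NormEMAVectorQuantizer and is not reproduced here.

import torch
import torch.nn as nn

encoder_embed_dim, quant_dim, quant_n = 768, 256, 1024  # defaults from TokenizersConfig
quantize_layer = nn.Sequential(
    nn.Linear(encoder_embed_dim, encoder_embed_dim),
    nn.Tanh(),
    nn.Linear(encoder_embed_dim, quant_dim),
)

x = torch.randn(1, 48, encoder_embed_dim)  # encoder output, one vector per patch
quantize_input = quantize_layer(x)         # projected into the codebook space
print(quantize_input.shape)                # (1, 48, 256); extract_labels then maps each vector
                                           # to one of quant_n = 1024 code indices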
NatureLM/models/beats/__init__.py ADDED
File without changes
NatureLM/models/beats/__pycache__/BEATs.cpython-310.pyc ADDED
Binary file (4.48 kB).
 
NatureLM/models/beats/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes).
 
NatureLM/models/beats/__pycache__/backbone.cpython-310.pyc ADDED
Binary file (16.6 kB).
 
NatureLM/models/beats/__pycache__/modules.cpython-310.pyc ADDED
Binary file (6.14 kB).
 
NatureLM/models/beats/backbone.py ADDED
@@ -0,0 +1,741 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from typing import Dict, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import Tensor, nn
17
+ from torch.nn import LayerNorm, Parameter
18
+
19
+ from .modules import (
20
+ GLU_Linear,
21
+ GradMultiply,
22
+ SamePad,
23
+ get_activation_fn,
24
+ quant_noise,
25
+ )
26
+
27
+
28
+ class TransformerEncoder(nn.Module):
29
+ def __init__(self, args):
30
+ super().__init__()
31
+
32
+ self.dropout = args.dropout
33
+ self.embedding_dim = args.encoder_embed_dim
34
+
35
+ self.pos_conv = nn.Conv1d(
36
+ self.embedding_dim,
37
+ self.embedding_dim,
38
+ kernel_size=args.conv_pos,
39
+ padding=args.conv_pos // 2,
40
+ groups=args.conv_pos_groups,
41
+ )
42
+ dropout = 0
43
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
44
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
45
+ nn.init.constant_(self.pos_conv.bias, 0)
46
+
47
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
48
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
49
+
50
+ if hasattr(args, "relative_position_embedding"):
51
+ self.relative_position_embedding = args.relative_position_embedding
52
+ self.num_buckets = args.num_buckets
53
+ self.max_distance = args.max_distance
54
+ else:
55
+ self.relative_position_embedding = False
56
+ self.num_buckets = 0
57
+ self.max_distance = 0
58
+
59
+ self.layers = nn.ModuleList(
60
+ [
61
+ TransformerSentenceEncoderLayer(
62
+ embedding_dim=self.embedding_dim,
63
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
64
+ num_attention_heads=args.encoder_attention_heads,
65
+ dropout=self.dropout,
66
+ attention_dropout=args.attention_dropout,
67
+ activation_dropout=args.activation_dropout,
68
+ activation_fn=args.activation_fn,
69
+ layer_norm_first=args.layer_norm_first,
70
+ deep_norm=args.deep_norm,
71
+ has_relative_attention_bias=self.relative_position_embedding,
72
+ num_buckets=self.num_buckets,
73
+ max_distance=self.max_distance,
74
+ gru_rel_pos=args.gru_rel_pos,
75
+ encoder_layers=args.encoder_layers,
76
+ )
77
+ for i in range(args.encoder_layers)
78
+ ]
79
+ )
80
+ if self.relative_position_embedding:
81
+ for i in range(1, args.encoder_layers):
82
+ del self.layers[i].self_attn.relative_attention_bias
83
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
84
+
85
+ self.layer_norm_first = args.layer_norm_first
86
+ self.layer_norm = LayerNorm(self.embedding_dim)
87
+ self.layerdrop = args.encoder_layerdrop
88
+
89
+ self.apply(init_bert_params)
90
+
91
+ if args.deep_norm:
92
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
93
+ for i in range(args.encoder_layers):
94
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
95
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
96
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
97
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
98
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
99
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
100
+
101
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
102
+
103
+ def forward(self, x, padding_mask=None, layer=None):
104
+ x, layer_results = self.extract_features(x, padding_mask, layer)
105
+
106
+ if self.layer_norm_first and layer is None:
107
+ x = self.layer_norm(x)
108
+
109
+ return x, layer_results
110
+
111
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
112
+ if padding_mask is not None:
113
+ x[padding_mask] = 0
114
+
115
+ x_conv = self.pos_conv(x.transpose(1, 2))
116
+ x_conv = x_conv.transpose(1, 2)
117
+ x = x + x_conv
118
+
119
+ if not self.layer_norm_first:
120
+ x = self.layer_norm(x)
121
+
122
+ x = F.dropout(x, p=self.dropout, training=self.training)
123
+
124
+ # B x T x C -> T x B x C
125
+ x = x.transpose(0, 1)
126
+
127
+ layer_results = []
128
+ z = None
129
+ if tgt_layer is not None:
130
+ layer_results.append((x, z))
131
+ r = None
132
+ pos_bias = None
133
+ for i, layer in enumerate(self.layers):
134
+ if self.layer_wise_gradient_decay_ratio != 1.0:
135
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
136
+ dropout_probability = np.random.random()
137
+ if not self.training or (dropout_probability > self.layerdrop):
138
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
139
+ if tgt_layer is not None:
140
+ layer_results.append((x, z))
141
+ if i == tgt_layer:
142
+ r = x
143
+ break
144
+
145
+ if r is not None:
146
+ x = r
147
+
148
+ # T x B x C -> B x T x C
149
+ x = x.transpose(0, 1)
150
+
151
+ return x, layer_results
152
+
153
+
154
+ class TransformerSentenceEncoderLayer(nn.Module):
155
+ def __init__(
156
+ self,
157
+ embedding_dim: float = 768,
158
+ ffn_embedding_dim: float = 3072,
159
+ num_attention_heads: float = 8,
160
+ dropout: float = 0.1,
161
+ attention_dropout: float = 0.1,
162
+ activation_dropout: float = 0.1,
163
+ activation_fn: str = "relu",
164
+ layer_norm_first: bool = False,
165
+ deep_norm: bool = False,
166
+ has_relative_attention_bias: bool = False,
167
+ num_buckets: int = 0,
168
+ max_distance: int = 0,
169
+ rescale_init: bool = False,
170
+ gru_rel_pos: bool = False,
171
+ encoder_layers: int = 0,
172
+ ) -> None:
173
+ super().__init__()
174
+ self.embedding_dim = embedding_dim
175
+ self.dropout = dropout
176
+ self.activation_dropout = activation_dropout
177
+
178
+ self.activation_name = activation_fn
179
+ self.activation_fn = get_activation_fn(activation_fn)
180
+ self.self_attn = MultiheadAttention(
181
+ self.embedding_dim,
182
+ num_attention_heads,
183
+ dropout=attention_dropout,
184
+ self_attention=True,
185
+ has_relative_attention_bias=has_relative_attention_bias,
186
+ num_buckets=num_buckets,
187
+ max_distance=max_distance,
188
+ rescale_init=rescale_init,
189
+ gru_rel_pos=gru_rel_pos,
190
+ )
191
+
192
+ self.dropout1 = nn.Dropout(dropout)
193
+ self.dropout2 = nn.Dropout(self.activation_dropout)
194
+ self.dropout3 = nn.Dropout(dropout)
195
+
196
+ self.layer_norm_first = layer_norm_first
197
+
198
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
199
+
200
+ if self.activation_name == "glu":
201
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
202
+ else:
203
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
204
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
205
+
206
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
207
+
208
+ self.deep_norm = deep_norm
209
+ if self.deep_norm:
210
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
211
+ else:
212
+ self.deep_norm_alpha = 1
213
+
214
+ def forward(
215
+ self,
216
+ x: torch.Tensor,
217
+ self_attn_mask: torch.Tensor = None,
218
+ self_attn_padding_mask: torch.Tensor = None,
219
+ need_weights: bool = False,
220
+ pos_bias=None,
221
+ ):
222
+ residual = x
223
+
224
+ if self.layer_norm_first:
225
+ x = self.self_attn_layer_norm(x)
226
+ x, attn, pos_bias = self.self_attn(
227
+ query=x,
228
+ key=x,
229
+ value=x,
230
+ key_padding_mask=self_attn_padding_mask,
231
+ need_weights=False,
232
+ attn_mask=self_attn_mask,
233
+ position_bias=pos_bias,
234
+ )
235
+ x = self.dropout1(x)
236
+ x = residual + x
237
+
238
+ residual = x
239
+ x = self.final_layer_norm(x)
240
+ if self.activation_name == "glu":
241
+ x = self.fc1(x)
242
+ else:
243
+ x = self.activation_fn(self.fc1(x))
244
+ x = self.dropout2(x)
245
+ x = self.fc2(x)
246
+ x = self.dropout3(x)
247
+ x = residual + x
248
+ else:
249
+ x, attn, pos_bias = self.self_attn(
250
+ query=x,
251
+ key=x,
252
+ value=x,
253
+ key_padding_mask=self_attn_padding_mask,
254
+ need_weights=need_weights,
255
+ attn_mask=self_attn_mask,
256
+ position_bias=pos_bias,
257
+ )
258
+
259
+ x = self.dropout1(x)
260
+ x = residual * self.deep_norm_alpha + x
261
+
262
+ x = self.self_attn_layer_norm(x)
263
+
264
+ residual = x
265
+ if self.activation_name == "glu":
266
+ x = self.fc1(x)
267
+ else:
268
+ x = self.activation_fn(self.fc1(x))
269
+ x = self.dropout2(x)
270
+ x = self.fc2(x)
271
+ x = self.dropout3(x)
272
+ x = residual * self.deep_norm_alpha + x
273
+ x = self.final_layer_norm(x)
274
+
275
+ return x, attn, pos_bias
276
+
277
+
278
+ class MultiheadAttention(nn.Module):
279
+ """Multi-headed attention.
280
+
281
+ See "Attention Is All You Need" for more details.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ embed_dim,
287
+ num_heads,
288
+ kdim=None,
289
+ vdim=None,
290
+ dropout=0.0,
291
+ bias=True,
292
+ add_bias_kv=False,
293
+ add_zero_attn=False,
294
+ self_attention=False,
295
+ encoder_decoder_attention=False,
296
+ q_noise=0.0,
297
+ qn_block_size=8,
298
+ has_relative_attention_bias=False,
299
+ num_buckets=32,
300
+ max_distance=128,
301
+ gru_rel_pos=False,
302
+ rescale_init=False,
303
+ ):
304
+ super().__init__()
305
+ self.embed_dim = embed_dim
306
+ self.kdim = kdim if kdim is not None else embed_dim
307
+ self.vdim = vdim if vdim is not None else embed_dim
308
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
309
+
310
+ self.num_heads = num_heads
311
+ self.dropout_module = nn.Dropout(dropout)
312
+
313
+ self.has_relative_attention_bias = has_relative_attention_bias
314
+ self.num_buckets = num_buckets
315
+ self.max_distance = max_distance
316
+ if self.has_relative_attention_bias:
317
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
318
+
319
+ self.head_dim = embed_dim // num_heads
320
+ self.q_head_dim = self.head_dim
321
+ self.k_head_dim = self.head_dim
322
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
323
+ self.scaling = self.head_dim**-0.5
324
+
325
+ self.self_attention = self_attention
326
+ self.encoder_decoder_attention = encoder_decoder_attention
327
+
328
+ assert not self.self_attention or self.qkv_same_dim, (
329
+ "Self-attention requires query, key and " "value to be of the same size"
330
+ )
331
+
332
+ k_bias = True
333
+ if rescale_init:
334
+ k_bias = False
335
+
336
+ k_embed_dim = embed_dim
337
+ q_embed_dim = embed_dim
338
+
339
+ self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
340
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
341
+ self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
342
+
343
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
344
+
345
+ if add_bias_kv:
346
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
347
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
348
+ else:
349
+ self.bias_k = self.bias_v = None
350
+
351
+ self.add_zero_attn = add_zero_attn
352
+
353
+ self.gru_rel_pos = gru_rel_pos
354
+ if self.gru_rel_pos:
355
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
356
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
357
+
358
+ self.reset_parameters()
359
+
360
+ def reset_parameters(self):
361
+ if self.qkv_same_dim:
362
+ # Empirically observed the convergence to be much better with
363
+ # the scaled initialization
364
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
365
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
366
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
367
+ else:
368
+ nn.init.xavier_uniform_(self.k_proj.weight)
369
+ nn.init.xavier_uniform_(self.v_proj.weight)
370
+ nn.init.xavier_uniform_(self.q_proj.weight)
371
+
372
+ nn.init.xavier_uniform_(self.out_proj.weight)
373
+ if self.out_proj.bias is not None:
374
+ nn.init.constant_(self.out_proj.bias, 0.0)
375
+ if self.bias_k is not None:
376
+ nn.init.xavier_normal_(self.bias_k)
377
+ if self.bias_v is not None:
378
+ nn.init.xavier_normal_(self.bias_v)
379
+ if self.has_relative_attention_bias:
380
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
381
+
382
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
383
+ num_buckets = self.num_buckets
384
+ max_distance = self.max_distance
385
+ relative_buckets = 0
386
+
387
+ if bidirectional:
388
+ num_buckets = num_buckets // 2
389
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
390
+ relative_positions = torch.abs(relative_positions)
391
+ else:
392
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
393
+
394
+ max_exact = num_buckets // 2
395
+ is_small = relative_positions < max_exact
396
+
397
+ relative_postion_if_large = max_exact + (
398
+ torch.log(relative_positions.float() / max_exact)
399
+ / math.log(max_distance / max_exact)
400
+ * (num_buckets - max_exact)
401
+ ).to(torch.long)
402
+ relative_postion_if_large = torch.min(
403
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
404
+ )
405
+
406
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
407
+ return relative_buckets
408
+
409
+ def compute_bias(self, query_length, key_length):
410
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
411
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
412
+ relative_position = memory_position - context_position
413
+ relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
414
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
415
+ values = self.relative_attention_bias(relative_position_bucket)
416
+ values = values.permute([2, 0, 1])
417
+ return values
418
+
419
+ def forward(
420
+ self,
421
+ query,
422
+ key: Optional[Tensor],
423
+ value: Optional[Tensor],
424
+ key_padding_mask: Optional[Tensor] = None,
425
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
426
+ need_weights: bool = True,
427
+ static_kv: bool = False,
428
+ attn_mask: Optional[Tensor] = None,
429
+ before_softmax: bool = False,
430
+ need_head_weights: bool = False,
431
+ position_bias: Optional[Tensor] = None,
432
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
433
+ """Input shape: Time x Batch x Channel
434
+
435
+ Args:
436
+ key_padding_mask (ByteTensor, optional): mask to exclude
437
+ keys that are pads, of shape `(batch, src_len)`, where
438
+ padding elements are indicated by 1s.
439
+ need_weights (bool, optional): return the attention weights,
440
+ averaged over heads (default: False).
441
+ attn_mask (ByteTensor, optional): typically used to
442
+ implement causal attention, where the mask prevents the
443
+ attention from looking forward in time (default: None).
444
+ before_softmax (bool, optional): return the raw attention
445
+ weights and values before the attention softmax.
446
+ need_head_weights (bool, optional): return the attention
447
+ weights for each head. Implies *need_weights*. Default:
448
+ return the average attention weights over all heads.
449
+ """
450
+ if need_head_weights:
451
+ need_weights = True
452
+
453
+ is_tpu = query.device.type == "xla"
454
+
455
+ tgt_len, bsz, embed_dim = query.size()
456
+ src_len = tgt_len
457
+ assert embed_dim == self.embed_dim
458
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
459
+ if key is not None:
460
+ src_len, key_bsz, _ = key.size()
461
+ if not torch.jit.is_scripting():
462
+ assert key_bsz == bsz
463
+ assert value is not None
464
+ assert value.shape[:2] == (src_len, bsz)
465
+
466
+ if self.has_relative_attention_bias and position_bias is None:
467
+ position_bias = self.compute_bias(tgt_len, src_len)
468
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
469
+
470
+ if incremental_state is not None:
471
+ saved_state = self._get_input_buffer(incremental_state)
472
+ if saved_state is not None and "prev_key" in saved_state:
473
+ # previous time steps are cached - no need to recompute
474
+ # key and value if they are static
475
+ if static_kv:
476
+ assert self.encoder_decoder_attention and not self.self_attention
477
+ key = value = None
478
+ else:
479
+ saved_state = None
480
+
481
+ if self.self_attention:
482
+ q = self.q_proj(query)
483
+ k = self.k_proj(query)
484
+ v = self.v_proj(query)
485
+ elif self.encoder_decoder_attention:
486
+ # encoder-decoder attention
487
+ q = self.q_proj(query)
488
+ if key is None:
489
+ assert value is None
490
+ k = v = None
491
+ else:
492
+ k = self.k_proj(key)
493
+ v = self.v_proj(key)
494
+
495
+ else:
496
+ assert key is not None and value is not None
497
+ q = self.q_proj(query)
498
+ k = self.k_proj(key)
499
+ v = self.v_proj(value)
500
+ q *= self.scaling
501
+ alpha = 32
502
+ q *= 1 / alpha
503
+
504
+ if self.bias_k is not None:
505
+ assert self.bias_v is not None
506
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
507
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
508
+ if attn_mask is not None:
509
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
510
+ if key_padding_mask is not None:
511
+ key_padding_mask = torch.cat(
512
+ [
513
+ key_padding_mask,
514
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
515
+ ],
516
+ dim=1,
517
+ )
518
+
519
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
520
+ if k is not None:
521
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
522
+ if v is not None:
523
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
524
+
525
+ if saved_state is not None:
526
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
527
+ if "prev_key" in saved_state:
528
+ _prev_key = saved_state["prev_key"]
529
+ assert _prev_key is not None
530
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
531
+ if static_kv:
532
+ k = prev_key
533
+ else:
534
+ assert k is not None
535
+ k = torch.cat([prev_key, k], dim=1)
536
+ src_len = k.size(1)
537
+ if "prev_value" in saved_state:
538
+ _prev_value = saved_state["prev_value"]
539
+ assert _prev_value is not None
540
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
541
+ if static_kv:
542
+ v = prev_value
543
+ else:
544
+ assert v is not None
545
+ v = torch.cat([prev_value, v], dim=1)
546
+ prev_key_padding_mask: Optional[Tensor] = None
547
+ if "prev_key_padding_mask" in saved_state:
548
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
549
+ assert k is not None and v is not None
550
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
551
+ key_padding_mask=key_padding_mask,
552
+ prev_key_padding_mask=prev_key_padding_mask,
553
+ batch_size=bsz,
554
+ src_len=k.size(1),
555
+ static_kv=static_kv,
556
+ )
557
+
558
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
559
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
560
+ saved_state["prev_key_padding_mask"] = key_padding_mask
561
+ # In this branch incremental_state is never None
562
+ assert incremental_state is not None
563
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
564
+ assert k is not None
565
+ assert k.size(1) == src_len
566
+
567
+ # This is part of a workaround to get around fork/join parallelism
568
+ # not supporting Optional types.
569
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
570
+ key_padding_mask = None
571
+
572
+ if key_padding_mask is not None:
573
+ assert key_padding_mask.size(0) == bsz
574
+ assert key_padding_mask.size(1) == src_len
575
+
576
+ if self.add_zero_attn:
577
+ assert v is not None
578
+ src_len += 1
579
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
580
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
581
+ if attn_mask is not None:
582
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
583
+ if key_padding_mask is not None:
584
+ key_padding_mask = torch.cat(
585
+ [
586
+ key_padding_mask,
587
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
588
+ ],
589
+ dim=1,
590
+ )
591
+
592
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
593
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
594
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
595
+
596
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
597
+
598
+ if attn_mask is not None:
599
+ attn_mask = attn_mask.unsqueeze(0)
600
+ attn_weights += attn_mask
601
+
602
+ if key_padding_mask is not None:
603
+ # don't attend to padding symbols
604
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
605
+ if not is_tpu:
606
+ attn_weights = attn_weights.masked_fill(
607
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
608
+ float("-inf"),
609
+ )
610
+ else:
611
+ attn_weights = attn_weights.transpose(0, 2)
612
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
613
+ attn_weights = attn_weights.transpose(0, 2)
614
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
615
+
616
+ if before_softmax:
617
+ return attn_weights, v, position_bias
618
+
619
+ if position_bias is not None:
620
+ attn_mask_rel_pos = position_bias
621
+ if self.gru_rel_pos == 1:
622
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
623
+ _B, _H, _L, __ = query_layer.size()
624
+ gate_a, gate_b = torch.sigmoid(
625
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
626
+ ).chunk(2, dim=-1)
627
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
628
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
629
+
630
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
631
+
632
+ attn_weights = attn_weights + attn_mask_rel_pos
633
+
634
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
635
+ attn_weights = attn_weights_float.type_as(attn_weights)
636
+ attn_probs = self.dropout_module(attn_weights)
637
+
638
+ assert v is not None
639
+ attn = torch.bmm(attn_probs, v)
640
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
641
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
642
+ attn = self.out_proj(attn)
643
+ attn_weights: Optional[Tensor] = None
644
+ if need_weights:
645
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
646
+ if not need_head_weights:
647
+ # average attention weights over heads
648
+ attn_weights = attn_weights.mean(dim=0)
649
+
650
+ return attn, attn_weights, position_bias
651
+
652
+ @staticmethod
653
+ def _append_prev_key_padding_mask(
654
+ key_padding_mask: Optional[Tensor],
655
+ prev_key_padding_mask: Optional[Tensor],
656
+ batch_size: int,
657
+ src_len: int,
658
+ static_kv: bool,
659
+ ) -> Optional[Tensor]:
660
+ # saved key padding masks have shape (bsz, seq_len)
661
+ if prev_key_padding_mask is not None and static_kv:
662
+ new_key_padding_mask = prev_key_padding_mask
663
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
664
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
665
+ # During incremental decoding, as the padding token enters and
666
+ # leaves the frame, there will be a time when prev or current
667
+ # is None
668
+ elif prev_key_padding_mask is not None:
669
+ if src_len > prev_key_padding_mask.size(1):
670
+ filler = torch.zeros(
671
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
672
+ device=prev_key_padding_mask.device,
673
+ )
674
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
675
+ else:
676
+ new_key_padding_mask = prev_key_padding_mask.float()
677
+ elif key_padding_mask is not None:
678
+ if src_len > key_padding_mask.size(1):
679
+ filler = torch.zeros(
680
+ (batch_size, src_len - key_padding_mask.size(1)),
681
+ device=key_padding_mask.device,
682
+ )
683
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
684
+ else:
685
+ new_key_padding_mask = key_padding_mask.float()
686
+ else:
687
+ new_key_padding_mask = prev_key_padding_mask
688
+ return new_key_padding_mask
689
+
690
+ def _get_input_buffer(
691
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
692
+ ) -> Dict[str, Optional[Tensor]]:
693
+ result = self.get_incremental_state(incremental_state, "attn_state")
694
+ if result is not None:
695
+ return result
696
+ else:
697
+ empty_result: Dict[str, Optional[Tensor]] = {}
698
+ return empty_result
699
+
700
+ def _set_input_buffer(
701
+ self,
702
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
703
+ buffer: Dict[str, Optional[Tensor]],
704
+ ):
705
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
706
+
707
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
708
+ return attn_weights
709
+
710
+
711
+ def init_bert_params(module):
712
+ """
713
+ Initialize the weights specific to the BERT Model.
714
+ This overrides the default initializations depending on the specified arguments.
715
+ 1. If normal_init_linear_weights is set then weights of linear
716
+ layer will be initialized using the normal distribution and
717
+ bais will be set to the specified value.
718
+ 2. If normal_init_embed_weights is set then weights of embedding
719
+ layer will be initialized using the normal distribution.
720
+ 3. If normal_init_proj_weights is set then weights of
721
+ in_project_weight for MultiHeadAttention will be initialized using
722
+ the normal distribution (to be validated).
723
+ """
724
+
725
+ def normal_(data):
726
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
727
+ # so that the RNG is consistent with and without FSDP
728
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
729
+
730
+ if isinstance(module, nn.Linear):
731
+ normal_(module.weight.data)
732
+ if module.bias is not None:
733
+ module.bias.data.zero_()
734
+ if isinstance(module, nn.Embedding):
735
+ normal_(module.weight.data)
736
+ if module.padding_idx is not None:
737
+ module.weight.data[module.padding_idx].zero_()
738
+ if isinstance(module, MultiheadAttention):
739
+ normal_(module.q_proj.weight.data)
740
+ normal_(module.k_proj.weight.data)
741
+ normal_(module.v_proj.weight.data)
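The relative attention bias above (_relative_positions_bucket and compute_bias in MultiheadAttention) follows the T5-style bucketing: nearby offsets get exact buckets, larger offsets share logarithmically spaced ones. The snippet below is a standalone sketch of that bucketing for the bidirectional case, not code from this diff, using the same defaults (num_buckets=32, max_distance=128) for illustration.

import math
import torch

def relative_position_bucket(relative_positions, num_buckets=32, max_distance=128):
    # bidirectional: half the buckets for positive offsets, half for negative
    num_buckets //= 2
    buckets = (relative_positions > 0).long() * num_buckets
    rel = relative_positions.abs()
    max_exact = num_buckets // 2
    is_small = rel < max_exact
    # larger offsets are mapped logarithmically into the remaining buckets
    large = max_exact + (
        torch.log(rel.float() / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(is_small, rel, large)

positions = torch.arange(6)
rel = positions[None, :] - positions[:, None]   # (query, key) offsets
print(relative_position_bucket(rel))            # 6 x 6 bucket ids, fed to an nn.Embedding of size (num_buckets, num_heads)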
NatureLM/models/beats/modules.py ADDED
@@ -0,0 +1,201 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn
16
+
17
+
18
+ class GradMultiply(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, x, scale):
21
+ ctx.scale = scale
22
+ res = x.new(x)
23
+ return res
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad):
27
+ return grad * ctx.scale, None
28
+
29
+
30
+ class SamePad(nn.Module):
31
+ def __init__(self, kernel_size, causal=False):
32
+ super().__init__()
33
+ if causal:
34
+ self.remove = kernel_size - 1
35
+ else:
36
+ self.remove = 1 if kernel_size % 2 == 0 else 0
37
+
38
+ def forward(self, x):
39
+ if self.remove > 0:
40
+ x = x[:, :, : -self.remove]
41
+ return x
42
+
43
+
44
+ class Swish(nn.Module):
45
+ def __init__(self):
46
+ super(Swish, self).__init__()
47
+ self.act = torch.nn.Sigmoid()
48
+
49
+ def forward(self, x):
50
+ return x * self.act(x)
51
+
52
+
53
+ class GLU_Linear(nn.Module):
54
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
55
+ super(GLU_Linear, self).__init__()
56
+
57
+ self.glu_type = glu_type
58
+ self.output_dim = output_dim
59
+
60
+ if glu_type == "sigmoid":
61
+ self.glu_act = torch.nn.Sigmoid()
62
+ elif glu_type == "swish":
63
+ self.glu_act = Swish()
64
+ elif glu_type == "relu":
65
+ self.glu_act = torch.nn.ReLU()
66
+ elif glu_type == "gelu":
67
+ self.glu_act = torch.nn.GELU()
68
+
69
+ if bias_in_glu:
70
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
71
+ else:
72
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
73
+
74
+ def forward(self, x):
75
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
76
+ x = self.linear(x)
77
+
78
+ if self.glu_type == "bilinear":
79
+ x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
80
+ else:
81
+ x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
82
+
83
+ return x
84
+
85
+
86
+ def gelu_accurate(x):
87
+ if not hasattr(gelu_accurate, "_a"):
88
+ gelu_accurate._a = math.sqrt(2 / math.pi)
89
+ return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
90
+
91
+
92
+ def gelu(x: torch.Tensor) -> torch.Tensor:
93
+ return torch.nn.functional.gelu(x.float()).type_as(x)
94
+
95
+
96
+ def get_activation_fn(activation: str):
97
+ """Returns the activation function corresponding to `activation`"""
98
+
99
+ if activation == "relu":
100
+ return F.relu
101
+ elif activation == "gelu":
102
+ return gelu
103
+ elif activation == "gelu_fast":
104
+ warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
105
+ return gelu_accurate
106
+ elif activation == "gelu_accurate":
107
+ return gelu_accurate
108
+ elif activation == "tanh":
109
+ return torch.tanh
110
+ elif activation == "linear":
111
+ return lambda x: x
112
+ elif activation == "glu":
113
+ return lambda x: x
114
+ else:
115
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
116
+
117
+
118
+ def quant_noise(module, p, block_size):
119
+ """
120
+ Wraps modules and applies quantization noise to the weights for
121
+ subsequent quantization with Iterative Product Quantization as
122
+ described in "Training with Quantization Noise for Extreme Model Compression"
123
+
124
+ Args:
125
+ - module: nn.Module
126
+ - p: amount of Quantization Noise
127
+ - block_size: size of the blocks for subsequent quantization with iPQ
128
+
129
+ Remarks:
130
+ - Module weights must have the right sizes wrt the block size
131
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
132
+ - For more detail on how to quantize by blocks with convolutional weights,
133
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
134
+ - We implement the simplest form of noise here as stated in the paper
135
+ which consists in randomly dropping blocks
136
+ """
137
+
138
+ # if no quantization noise, don't register hook
139
+ if p <= 0:
140
+ return module
141
+
142
+ # supported modules
143
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
144
+
145
+ # test whether module.weight has the right sizes wrt block_size
146
+ is_conv = module.weight.ndim == 4
147
+
148
+ # 2D matrix
149
+ if not is_conv:
150
+ assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
151
+
152
+ # 4D matrix
153
+ else:
154
+ # 1x1 convolutions
155
+ if module.kernel_size == (1, 1):
156
+ assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
157
+ # regular convolutions
158
+ else:
159
+ k = module.kernel_size[0] * module.kernel_size[1]
160
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
161
+
162
+ def _forward_pre_hook(mod, input):
163
+ # no noise for evaluation
164
+ if mod.training:
165
+ if not is_conv:
166
+ # gather weight and sizes
167
+ weight = mod.weight
168
+ in_features = weight.size(1)
169
+ out_features = weight.size(0)
170
+
171
+ # split weight matrix into blocks and randomly drop selected blocks
172
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
173
+ mask.bernoulli_(p)
174
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
175
+
176
+ else:
177
+ # gather weight and sizes
178
+ weight = mod.weight
179
+ in_channels = mod.in_channels
180
+ out_channels = mod.out_channels
181
+
182
+ # split weight matrix into blocks and randomly drop selected blocks
183
+ if mod.kernel_size == (1, 1):
184
+ mask = torch.zeros(
185
+ int(in_channels // block_size * out_channels),
186
+ device=weight.device,
187
+ )
188
+ mask.bernoulli_(p)
189
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
190
+ else:
191
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
192
+ mask.bernoulli_(p)
193
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
194
+
195
+ # scale weights and apply mask
196
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
197
+ s = 1 / (1 - p)
198
+ mod.weight.data = s * weight.masked_fill(mask, 0)
199
+
200
+ module.register_forward_pre_hook(_forward_pre_hook)
201
+ return module
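quant_noise above works by registering a forward pre-hook that, during training, zeroes randomly chosen weight blocks and rescales the remainder by 1/(1-p). A minimal usage sketch, assuming the module path introduced in this diff (the layer sizes are illustrative):

import torch
from torch import nn
from NatureLM.models.beats.modules import quant_noise

# in_features must be a multiple of block_size for a Linear layer
layer = quant_noise(nn.Linear(64, 32), p=0.25, block_size=8)
layer.train()
out = layer(torch.randn(4, 64))        # pre-hook drops weight blocks and rescales
layer.eval()
out_eval = layer(torch.randn(4, 64))   # hook is a no-op outside training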
NatureLM/models/beats/quantizer.py ADDED
@@ -0,0 +1,222 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------'
9
+
10
+ import torch
11
+ import torch.distributed as distributed
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
+ def l2norm(t):
22
+ return F.normalize(t, p=2, dim=-1)
23
+
24
+
25
+ def ema_inplace(moving_avg, new, decay):
26
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
50
+ dists = -(diffs**2).sum(dim=-1)
51
+
52
+ buckets = dists.max(dim=-1).indices
53
+ bins = torch.bincount(buckets, minlength=num_clusters)
54
+ zero_mask = bins == 0
55
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
56
+
57
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
58
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
59
+ new_means = new_means / bins_min_clamped[..., None]
60
+
61
+ if use_cosine_sim:
62
+ new_means = l2norm(new_means)
63
+
64
+ means = torch.where(zero_mask[..., None], means, new_means)
65
+
66
+ return means, bins
67
+
68
+
69
+ class EmbeddingEMA(nn.Module):
70
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=""):
71
+ super().__init__()
72
+ self.num_tokens = num_tokens
73
+ self.codebook_dim = codebook_dim
74
+ self.decay = decay
75
+ self.eps = eps
76
+ if codebook_init_path == "":
77
+ if not kmeans_init:
78
+ weight = torch.randn(num_tokens, codebook_dim)
79
+ weight = l2norm(weight)
80
+ else:
81
+ weight = torch.zeros(num_tokens, codebook_dim)
82
+ self.register_buffer("initted", torch.Tensor([not kmeans_init]))
83
+ else:
84
+ print(f"load init codebook weight from {codebook_init_path}")
85
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location="cpu")
86
+ weight = codebook_ckpt_weight.clone()
87
+ self.register_buffer("initted", torch.Tensor([True]))
88
+
89
+ self.weight = nn.Parameter(weight, requires_grad=False)
90
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
91
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
92
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
93
+ self.update = True
94
+
95
+ @torch.jit.ignore
96
+ def init_embed_(self, data):
97
+ if self.initted:
98
+ return
99
+ print("Performing K-means init for codebook")
100
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
101
+ self.weight.data.copy_(embed)
102
+ self.cluster_size.data.copy_(cluster_size)
103
+ self.initted.data.copy_(torch.Tensor([True]))
104
+
105
+ def forward(self, embed_id):
106
+ return F.embedding(embed_id, self.weight)
107
+
108
+ def cluster_size_ema_update(self, new_cluster_size):
109
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
110
+
111
+ def embed_avg_ema_update(self, new_embed_avg):
112
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
113
+
114
+ def weight_update(self, num_tokens):
115
+ n = self.cluster_size.sum()
116
+ smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
117
+ # normalize embedding average with smoothed cluster size
118
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
119
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
120
+ self.weight.data.copy_(embed_normalized)
121
+
122
+
123
+ def norm_ema_inplace(moving_avg, new, decay):
124
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
125
+ moving_avg.data.copy_(l2norm(moving_avg.data))
126
+
127
+
128
+ class NormEMAVectorQuantizer(nn.Module):
129
+ def __init__(
130
+ self,
131
+ n_embed,
132
+ embedding_dim,
133
+ beta,
134
+ decay=0.99,
135
+ eps=1e-5,
136
+ statistic_code_usage=True,
137
+ kmeans_init=False,
138
+ codebook_init_path="",
139
+ ):
140
+ super().__init__()
141
+ self.codebook_dim = embedding_dim
142
+ self.num_tokens = n_embed
143
+ self.beta = beta
144
+ self.decay = decay
145
+
146
+ # learnable = True if orthogonal_reg_weight > 0 else False
147
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
148
+
149
+ self.statistic_code_usage = statistic_code_usage
150
+ if statistic_code_usage:
151
+ self.register_buffer("cluster_size", torch.zeros(n_embed))
152
+ if distributed.is_available() and distributed.is_initialized():
153
+ print("ddp is enabled, so use all_reduce to sync the statistic_code_usage across gpus!")
154
+ self.all_reduce_fn = distributed.all_reduce
155
+ else:
156
+ self.all_reduce_fn = nn.Identity()
157
+
158
+ def reset_cluster_size(self, device):
159
+ if self.statistic_code_usage:
160
+ self.register_buffer("cluster_size", torch.zeros(self.num_tokens))
161
+ self.cluster_size = self.cluster_size.to(device)
162
+
163
+ def forward(self, z):
164
+ # reshape z -> (batch, height, width, channel) and flatten
165
+ # z, 'b c h w -> b h w c'
166
+ # z = rearrange(z, 'b c h w -> b h w c')
167
+ # z = z.transpose(1, 2)
168
+ z = l2norm(z)
169
+ z_flattened = z.reshape(-1, self.codebook_dim)
170
+
171
+ self.embedding.init_embed_(z_flattened)
172
+
173
+ d = (
174
+ z_flattened.pow(2).sum(dim=1, keepdim=True)
175
+ + self.embedding.weight.pow(2).sum(dim=1)
176
+ - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
177
+ ) # 'n d -> d n'
178
+
179
+ encoding_indices = torch.argmin(d, dim=1)
180
+
181
+ z_q = self.embedding(encoding_indices).view(z.shape)
182
+
183
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
184
+
185
+ if not self.training:
186
+ with torch.no_grad():
187
+ cluster_size = encodings.sum(0)
188
+ self.all_reduce_fn(cluster_size)
189
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
190
+
191
+ if self.training and self.embedding.update:
192
+ # EMA cluster size
193
+
194
+ bins = encodings.sum(0)
195
+ self.all_reduce_fn(bins)
196
+
197
+ # self.embedding.cluster_size_ema_update(bins)
198
+ ema_inplace(self.cluster_size, bins, self.decay)
199
+
200
+ zero_mask = bins == 0
201
+ bins = bins.masked_fill(zero_mask, 1.0)
202
+
203
+ embed_sum = z_flattened.t() @ encodings
204
+ self.all_reduce_fn(embed_sum)
205
+
206
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
207
+ embed_normalized = l2norm(embed_normalized)
208
+
209
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, embed_normalized)
210
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
211
+
212
+ # compute loss for embedding
213
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
214
+
215
+ # preserve gradients
216
+ z_q = z + (z_q - z).detach()
217
+
218
+ # reshape back to match original input shape
219
+ # z_q, 'b h w c -> b c h w'
220
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
221
+ # z_q = z_q.transpose(1, 2)
222
+ return z_q, loss, encoding_indices
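A minimal usage sketch for NormEMAVectorQuantizer above, assuming the module path introduced in this diff; the codebook size and feature shape are illustrative. The forward returns the l2-normalised quantised features (with a straight-through gradient), the commitment loss, and the code indices.

import torch
from NatureLM.models.beats.quantizer import NormEMAVectorQuantizer

vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
features = torch.randn(8, 100, 256)      # (batch, time, dim) encoder outputs
z_q, commit_loss, codes = vq(features)
print(z_q.shape, codes.shape)            # torch.Size([8, 100, 256]) torch.Size([800])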
NatureLM/models/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from transformers import StoppingCriteria
17
+
18
+
19
+ class StoppingCriteriaSub(StoppingCriteria):
20
+ def __init__(self, stops=[], encounters=1):
21
+ super().__init__()
22
+ self.stops = stops
23
+
24
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
25
+ for stop in self.stops:
26
+ if torch.all((stop == input_ids[0][-len(stop) :])).item():
27
+ return True
28
+
29
+ return False
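StoppingCriteriaSub halts generation once the last tokens match one of the given stop sequences. A hedged usage sketch with transformers' StoppingCriteriaList; the stop token id below is a placeholder and would normally come from the tokenizer, and the generate call is shown commented out since the model object is not part of this file.

import torch
from transformers import StoppingCriteriaList
from NatureLM.models.utils import StoppingCriteriaSub

stop_ids = [torch.tensor([128009])]   # illustrative end-of-turn id, e.g. tokenizer("<|eot_id|>").input_ids
stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])
# outputs = model.generate(**inputs, stopping_criteria=stopping)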
NatureLM/optims.py ADDED
@@ -0,0 +1,154 @@
1
+ # This script is from https://github.com/salesforce/LAVIS/blob/main/lavis/common/optims.py
2
+
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+
8
+ from NatureLM.config import OptimizerConfig
9
+
10
+
11
+ class LinearWarmupStepLRScheduler:
12
+ def __init__(
13
+ self,
14
+ optimizer,
15
+ max_epoch,
16
+ min_lr,
17
+ init_lr,
18
+ decay_rate=1,
19
+ warmup_start_lr=-1,
20
+ warmup_steps=0,
21
+ **kwargs,
22
+ ):
23
+ self.optimizer = optimizer
24
+
25
+ self.max_epoch = max_epoch
26
+ self.min_lr = min_lr
27
+
28
+ self.decay_rate = decay_rate
29
+
30
+ self.init_lr = init_lr
31
+ self.warmup_steps = warmup_steps
32
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
33
+
34
+ def step(self, cur_epoch, cur_step):
35
+ if cur_epoch == 0:
36
+ warmup_lr_schedule(
37
+ step=cur_step,
38
+ optimizer=self.optimizer,
39
+ max_step=self.warmup_steps,
40
+ init_lr=self.warmup_start_lr,
41
+ max_lr=self.init_lr,
42
+ )
43
+ else:
44
+ step_lr_schedule(
45
+ epoch=cur_epoch,
46
+ optimizer=self.optimizer,
47
+ init_lr=self.init_lr,
48
+ min_lr=self.min_lr,
49
+ decay_rate=self.decay_rate,
50
+ )
51
+
52
+
53
+ class LinearWarmupCosineLRScheduler:
54
+ def __init__(
55
+ self,
56
+ optimizer,
57
+ max_epoch,
58
+ iters_per_epoch,
59
+ min_lr,
60
+ init_lr,
61
+ warmup_steps=0,
62
+ warmup_start_lr=-1,
63
+ **kwargs,
64
+ ):
65
+ self.optimizer = optimizer
66
+
67
+ self.max_epoch = max_epoch
68
+ self.iters_per_epoch = iters_per_epoch
69
+ self.min_lr = min_lr
70
+
71
+ self.init_lr = init_lr
72
+ self.warmup_steps = warmup_steps
73
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
74
+
75
+ def step(self, cur_epoch, cur_step):
76
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
77
+ if total_cur_step < self.warmup_steps:
78
+ warmup_lr_schedule(
79
+ step=cur_step,
80
+ optimizer=self.optimizer,
81
+ max_step=self.warmup_steps,
82
+ init_lr=self.warmup_start_lr,
83
+ max_lr=self.init_lr,
84
+ )
85
+ else:
86
+ cosine_lr_schedule(
87
+ epoch=total_cur_step,
88
+ optimizer=self.optimizer,
89
+ max_epoch=self.max_epoch * self.iters_per_epoch,
90
+ init_lr=self.init_lr,
91
+ min_lr=self.min_lr,
92
+ )
93
+
94
+
95
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
96
+ """Decay the learning rate"""
97
+ lr = (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
98
+ for param_group in optimizer.param_groups:
99
+ param_group["lr"] = lr
100
+
101
+
102
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
103
+ """Warmup the learning rate"""
104
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
105
+ for param_group in optimizer.param_groups:
106
+ param_group["lr"] = lr
107
+
108
+
109
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
110
+ """Decay the learning rate"""
111
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
112
+ for param_group in optimizer.param_groups:
113
+ param_group["lr"] = lr
114
+
115
+
116
+ def get_optimizer(model, config: OptimizerConfig):
117
+ num_parameters = 0
118
+ p_wd, p_non_wd = [], []
119
+ for n, p in model.named_parameters():
120
+ if not p.requires_grad:
121
+ continue # frozen weights
122
+ print(n)
123
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
124
+ p_non_wd.append(p)
125
+ else:
126
+ p_wd.append(p)
127
+ num_parameters += p.data.nelement()
128
+ logging.info("number of trainable parameters: %d" % num_parameters)
129
+ optim_params = [
130
+ {
131
+ "params": p_wd,
132
+ "weight_decay": float(config.weight_decay),
133
+ },
134
+ {"params": p_non_wd, "weight_decay": 0},
135
+ ]
136
+ beta2 = config.beta2
137
+ if config.device == "cpu":
138
+ optimizer = torch.optim.AdamW(
139
+ optim_params,
140
+ lr=float(config.init_lr),
141
+ weight_decay=float(config.weight_decay),
142
+ betas=(0.9, beta2),
143
+ )
144
+ else:
145
+ import bitsandbytes as bnb
146
+
147
+ optimizer = bnb.optim.PagedAdamW8bit(
148
+ optim_params,
149
+ lr=float(config.init_lr),
150
+ weight_decay=float(config.weight_decay),
151
+ betas=(0.9, beta2),
152
+ )
153
+
154
+ return optimizer
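The scheduler above warms the learning rate up linearly for warmup_steps and then decays it with a cosine from init_lr to min_lr over max_epoch * iters_per_epoch steps. A quick sanity check with a dummy parameter group; the hyperparameter values are illustrative, not the project's defaults.

import torch
from NatureLM.optims import LinearWarmupCosineLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = LinearWarmupCosineLRScheduler(
    opt, max_epoch=10, iters_per_epoch=100, min_lr=1e-5, init_lr=1e-4,
    warmup_steps=50, warmup_start_lr=1e-6,
)
for step in (0, 25, 50, 500, 999):
    sched.step(cur_epoch=step // 100, cur_step=step % 100)
    print(step, opt.param_groups[0]["lr"])   # ramps up to 1e-4, then decays toward 1e-5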
NatureLM/processors.py ADDED
@@ -0,0 +1,278 @@
1
+ """Module contains the audio and text processor for NatureLM-audio inference and evaluation"""
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass, field
6
+
7
+ import numpy as np
8
+ import resampy
9
+ import soundfile as sf
10
+ import torch
11
+
12
+
13
+ @dataclass
14
+ class NatureLMAudioProcessor:
15
+ """Preprocess samples to make them ready for NatureLM-audio inference.
16
+
17
+ Arguments
18
+ ---------
19
+ sample_rate : int
20
+ The sample rate of the NatureLM model
21
+ max_length_seconds : int
22
+ The maximum length of audio in seconds
23
+ audio_token_placeholder : str
24
+ The placeholder for the audio token in the instruction
25
+ prompt_template : str
26
+ The template for the prompt. The instruction or query from the user is inserted in the placeholder at {prompt}
27
+
28
+
29
+ Examples
30
+ --------
31
+ >>> processor = NatureLMAudioProcessor()
32
+ >>> audios = [np.random.rand(32000), np.random.rand(32000)]
33
+ >>> instructions = ["What is the weather today?", "What is the time now?"]
34
+ >>> input_sample_rates = [32000, 32000]
35
+ >>> audios, instructions = processor(audios, instructions, input_sample_rates)
36
+ >>> audios.shape == (2, 160000)
37
+ True
38
+ >>> "<Audio><AudioHere></Audio> " in instructions[0]
39
+ True
40
+ >>> "<|start_header_id|>user<|end_header_id|>" in instructions[0]
41
+ True
42
+ """
43
+
44
+ sample_rate: int = 16000
45
+ max_length_seconds: int = 10
46
+ audio_token_placeholder: str = "<Audio><AudioHere></Audio> "
47
+ prompt_template: str = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
48
+
49
+ def prepare_audio(self, audio: list[float] | np.ndarray | os.PathLike, input_sr: int = None) -> torch.Tensor:
50
+ """Prepare an audio array or file path for inference"""
51
+ if isinstance(audio, str | os.PathLike):
52
+ audio, sr = sf.read(audio)
53
+ input_sr = sr
54
+ elif isinstance(audio, list):
55
+ audio = np.array(audio)
56
+
57
+ assert isinstance(audio, np.ndarray), "Audio not a numpy array"
58
+
59
+ # Convert stereo to mono
60
+ if len(audio.shape) == 2:
61
+ # treat the smaller axis as the channel dim and average over it (shapes like (2, T) or (T, 2))
62
+ axis_to_average = int(np.argmin(audio.shape))
63
+ audio = audio.mean(axis=axis_to_average)
64
+
65
+ # Resample
66
+ if input_sr is not None and input_sr != self.sample_rate:
67
+ # audio = torchaudio.functional.resample(
68
+ # torch.from_numpy(audio), orig_freq=input_sr, new_freq=self.sample_rate
69
+ # )
70
+ audio = resampy.resample(audio, input_sr, self.sample_rate)
71
+ audio = torch.from_numpy(audio.squeeze())
72
+ else:
73
+ audio = torch.from_numpy(audio)
74
+
75
+ # Truncate audio to at most max_length_seconds
76
+ audio = audio[: self.sample_rate * self.max_length_seconds]
77
+
78
+ # Pad to max_length_seconds if short
79
+ if len(audio) < self.sample_rate * self.max_length_seconds:
80
+ pad_size = self.sample_rate * self.max_length_seconds - len(audio)
81
+ audio = torch.nn.functional.pad(audio, (0, pad_size))
82
+
83
+ # Clamp
84
+ audio = torch.clamp(audio, -1.0, 1.0)
85
+
86
+ return audio.squeeze()
87
+
88
+ def prepare_instruction(self, instruction: str) -> str:
89
+ """Add the audio token placeholder to the instruction and format it
90
+ according to the llama tokenizer.
91
+ """
92
+ if self.audio_token_placeholder not in instruction:
93
+ instruction = self.audio_token_placeholder + instruction
94
+ instruction = self.prompt_template.format(prompt=instruction.strip())
95
+
96
+ return instruction
97
+
98
+ def __call__(
99
+ self,
100
+ audios: list[list[float] | np.ndarray] | list[str | os.PathLike],
101
+ instructions: list[str],
102
+ input_sample_rates: list[int],
103
+ ) -> tuple[torch.Tensor, list[str]]:
104
+ """Prepare audios and instructions for inference
105
+
106
+ Arguments
107
+ ---------
108
+ audios : list[list[float] | np.ndarray] | list[str | os.PathLike]
109
+ The audio samples or file paths
110
+ instructions : list[str]
111
+ The instructions or queries
112
+ input_sample_rates : list[int]
113
+ The sample rates of the input audio samples
114
+
115
+ Returns
116
+ -------
117
+ tuple[torch.Tensor, list[str]]
118
+ The prepared audios and instructions
119
+ """
120
+ audios = torch.stack(
121
+ [self.prepare_audio(audio, input_sr) for audio, input_sr in zip(audios, input_sample_rates)]
122
+ )
123
+ instructions = [self.prepare_instruction(instruction) for instruction in instructions]
124
+
125
+ return audios, instructions
126
+
127
+
128
+ @dataclass
129
+ class NatureLMAudioEvalProcessor(NatureLMAudioProcessor):
130
+ """Preprocess samples to make them ready for NatureLM-audio evaluation on BEANS-Zero dataset.
131
+ This requires a few additional parameters compared to the NatureLMAudioProcessor.
132
+
133
+ Arguments
134
+ ---------
135
+ sample_rate : int
136
+ The sample rate of the NatureLM model
137
+ max_length_seconds : int
138
+ The maximum length of audio in seconds
139
+ audio_token_placeholder : str
140
+ The placeholder for the audio token in the instruction
141
+ prompt_template : str
142
+ The template for the prompt. The instruction or query from the user is inserted in the placeholder at {prompt}
143
+
144
+ dataset_name : list[str]
145
+ The name of the dataset being processed
146
+ true_labels : list[str]
147
+ The true labels or expected outputs for the samples.
148
+ task: str
149
+ The task for the dataset. Can be 'detection', 'captioning', or 'classification'
150
+ threshold_too_many_detection_labels : int
151
+ The threshold for the number of labels in the dataset to switch to a detection prompt. Default is 8.
152
+
153
+
154
+ Examples
155
+ --------
156
+ >>> processor = NatureLMAudioEvalProcessor(task="detection", true_labels=["dog", "cat", "bird", "None", "mouse", "elephant", "lion", "tiger", "bear"])
157
+ >>> audios = [np.random.rand(32000), np.random.rand(32000)]
158
+ >>> instructions = ["What is the weather today?", "What is the time now?"]
159
+ >>> input_sample_rates = [32000, 32000]
160
+ >>> audios, instructions = processor(audios, instructions, input_sample_rates)
161
+ >>> audios.shape == (2, 160000)
162
+ True
163
+ >>> "<Audio><AudioHere></Audio> " in instructions[0]
164
+ True
165
+ >>> "<|start_header_id|>user<|end_header_id|>" in instructions[0]
166
+ True
167
+ >>> "What are the common names" in instructions[0]
168
+ True
169
+ """
170
+
171
+ dataset_name: str = "beans-zero"
172
+ true_labels: list[str] = field(default_factory=list)
173
+ task: str = "detection"
174
+
175
+ threshold_too_many_detection_labels: int = 8
176
+
177
+ def __post_init__(self):
178
+ self.detection_prompt: str = (
179
+ "<Audio><AudioHere></Audio> What are the common names for the species in the audio, if any?"
180
+ )
181
+
182
+ # find the unique labels in the dataset
183
+ self.dataset_labels = set(self.true_labels)
184
+ if self.task == "detection":
185
+ self.dataset_labels.add("None")
186
+ if self.task == "captioning":
187
+ self.dataset_labels = set()
188
+
189
+ def prepare_instruction(self, instruction: str) -> str:
190
+ """Add the audio token placeholder to the instruction and format it"""
191
+ if self.task == "detection" and len(self.dataset_labels) > self.threshold_too_many_detection_labels:
192
+ instruction = self.detection_prompt
193
+
194
+ if self.audio_token_placeholder not in instruction:
195
+ instruction = self.audio_token_placeholder + instruction
196
+
197
+ instruction = self.prompt_template.format(prompt=instruction.strip())
198
+
199
+ return instruction
200
+
201
+
202
+ class NatureLMInferenceDataset(torch.utils.data.Dataset):
203
+ """A pytorch dataset for batched inference with NatureLM-audio
204
+
205
+ TODO: currently, if the batch contains very different prompts the model doesn't work well.
206
+
207
+ Arguments
208
+ ---------
209
+ ds : datasets.Dataset
210
+ The huggingface dataset containing the samples
211
+
212
+ Examples
213
+ --------
214
+ TODO: Add examples
215
+ """
216
+
217
+ def __init__(self, ds, processor):
218
+ self.ds = ds
219
+ self.processor = processor
220
+
221
+ def __getitem__(self, idx):
222
+ sample = self.ds[idx]
223
+ input_sample_rate = json.loads(sample["metadata"])["sample_rate"]
224
+ audio_tensor = self.processor.prepare_audio(sample["audio"], input_sample_rate)
225
+
226
+ instruction = self.processor.prepare_instruction(sample["instruction"])
227
+ return {
228
+ "raw_wav": audio_tensor,
229
+ "text": "",
230
+ "task": sample["task"],
231
+ "audio_chunk_sizes": len(audio_tensor),
232
+ "index": idx,
233
+ "id": sample["id"],
234
+ "prompt": instruction,
235
+ "label": sample["output"],
236
+ }
237
+
238
+ def __len__(self):
239
+ return len(self.ds)
240
+
241
+
242
+ def collater(samples: list[dict]) -> dict:
243
+ """Collate samples into a batch.
244
+
245
+ Samples is a list of dictionaries, each containing the following keys:
246
+ - raw_wav: a list of tensors containing the raw audio waveform
247
+ - text: a list of strings containing the text
248
+ - task: a list of strings containing the task
249
+ - id: a list of strings containing the id
250
+ - prompt: a list of strings containing the prompt
251
+ - index: a list of integers containing the index
252
+ - audio_chunk_sizes: a list of integers containing the size of each audio chunk
253
+
254
+ The individual audio waveforms will be stacked along the batch dimension for easier
255
+ processing in the audio model. To keep track of which audio belongs to which sample, we add
256
+ the audio_chunk_sizes key to the batch dictionary.
257
+ """
258
+ raw_wav = torch.stack([s["raw_wav"] for s in samples])
259
+ padding_mask = torch.zeros_like(raw_wav).to(torch.bool)
260
+
261
+ text = [s["text"] for s in samples]
262
+ prompt = [s["prompt"] for s in samples]
263
+ task = [s["task"] for s in samples]
264
+ id = [s["id"] for s in samples]
265
+ index = [s["index"] for s in samples]
266
+ label = [s["label"] for s in samples]
267
+
268
+ return {
269
+ "raw_wav": raw_wav,
270
+ "padding_mask": padding_mask,
271
+ "text": text,
272
+ "task": task,
273
+ "id": id,
274
+ "prompt": prompt,
275
+ "index": index,
276
+ "audio_chunk_sizes": 1,
277
+ "label": label,
278
+ }
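A short sketch of how the processor above prepares a batch for inference; the audio arrays and prompts are placeholders. Inputs at other sample rates are resampled to 16 kHz and padded or truncated to max_length_seconds.

import numpy as np
from NatureLM.processors import NatureLMAudioProcessor

processor = NatureLMAudioProcessor(sample_rate=16000, max_length_seconds=10)
audios = [np.random.rand(32000), np.random.rand(48000)]
prompts = ["What species is vocalizing?", "Is there rain in the background?"]
batch_audio, batch_prompts = processor(audios, prompts, input_sample_rates=[32000, 48000])
print(batch_audio.shape)   # torch.Size([2, 160000]) -> 10 s at 16 kHz per sample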
NatureLM/runner.py ADDED
@@ -0,0 +1,515 @@
1
+ # This script is based on https://github.com/salesforce/LAVIS/blob/main/lavis/runners/runner_base.py
2
+
3
+ import datetime
4
+ import json
5
+ import logging
6
+ import os
7
+ import time
8
+ from collections import defaultdict
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.distributed
13
+ import torch.distributed as dist
14
+ import wandb
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from NatureLM.config import Config
19
+ from NatureLM.dist_utils import get_rank, get_world_size, is_dist_avail_and_initialized, is_main_process, main_process
20
+ from NatureLM.logger import MetricLogger, SmoothedValue
21
+ from NatureLM.optims import LinearWarmupCosineLRScheduler, get_optimizer
22
+ from NatureLM.task_metrics import get_task_metrics
23
+ from NatureLM.utils import get_dataloader, prepare_sample_dist
24
+
25
+
26
+ class Runner:
27
+ def __init__(self, cfg: Config, model, datasets, job_id):
28
+ self.config = cfg
29
+
30
+ # log
31
+ device = "cuda:0"
32
+ if is_main_process():
33
+ if self.config.run.wandb_enabled:
34
+ wandb.init(project="earthlm", config=self.config.model_dump())
35
+ else:
36
+ wandb.init(mode="disabled")
37
+
38
+ if "LOCAL_RANK" in os.environ:
39
+ device = int(os.environ["LOCAL_RANK"])
40
+ else:
41
+ device = self.config.run.device
42
+ print(f"device is {device} could have been {self.config.run.device}")
43
+ self.output_dir = Path(self.config.run.output_dir) / job_id
44
+ self.output_dir.mkdir(parents=True, exist_ok=True)
45
+ self.log_writter = SummaryWriter(self.output_dir)
46
+
47
+ # settings
48
+ self.device = torch.device(device)
49
+ self.use_distributed = self.config.run.use_distributed
50
+ self.start_epoch = 0
51
+ self.max_epoch = self.config.run.optims.max_epoch
52
+ self.evaluate_only = self.config.run.evaluate
53
+ self.cuda_enabled = self.device.type == "cuda"
54
+
55
+ # test prompt
56
+ self.prompt_template = self.config.model.prompt_template
57
+
58
+ # model
59
+ self._model = model
60
+ torch.nn.SyncBatchNorm.convert_sync_batchnorm(self._model)
61
+ self._model.to(self.device)
62
+ if self.use_distributed:
63
+ self.model = DDP(
64
+ self._model,
65
+ find_unused_parameters=True,
66
+ static_graph=False,
67
+ device_ids=[self.device],
68
+ )
69
+ else:
70
+ self.model = self._model
71
+
72
+ # dataloaders
73
+ self.train_loader = get_dataloader(
74
+ datasets["train"],
75
+ self.config.run,
76
+ is_train=True,
77
+ use_distributed=self.use_distributed,
78
+ )
79
+ self.valid_loader = get_dataloader(
80
+ datasets["valid"],
81
+ self.config.run,
82
+ is_train=False,
83
+ use_distributed=self.use_distributed,
84
+ )
85
+ self.test_loader = get_dataloader(
86
+ datasets["test"],
87
+ self.config.run,
88
+ is_train=False,
89
+ use_distributed=self.use_distributed,
90
+ )
91
+
92
+ # scaler
93
+ self.use_amp = self.config.run.amp
94
+ if self.use_amp:
95
+ self.scaler = torch.cuda.amp.GradScaler()
96
+ else:
97
+ self.scaler = None
98
+
99
+ # optimizer & scheduler
100
+ self.iters_per_epoch = (
101
+ len(self.train_loader) if self.config.run.epoch_based else self.config.run.iters_per_epoch
102
+ )
103
+ self.optimizer = get_optimizer(self.model, self.config.run.optims)
104
+ self.scheduler = LinearWarmupCosineLRScheduler(
105
+ self.optimizer,
106
+ max_epoch=self.max_epoch,
107
+ iters_per_epoch=self.iters_per_epoch,
108
+ min_lr=self.config.run.optims.min_lr,
109
+ init_lr=self.config.run.optims.init_lr,
110
+ warmup_steps=self.config.run.optims.warmup_steps,
111
+ warmup_start_lr=self.config.run.optims.warmup_start_lr,
112
+ )
113
+
114
+ #### augmentations
115
+ # self.rng = random.Random(self.config.run.seed)
116
+ # self.rngnp = np.random.default_rng(seed=self.config.run.seed)
117
+ # self.rngth = torch.Generator(device=args.device)
118
+ # self.rngth.manual_seed(self.config.run.seed)
119
+ # augments = []
120
+ # if self.config.run.augmentations.flip:
121
+ # augments.append(augmentations.Flip(self.config.run.augmentations.flip, rngth=self.rngth, seed=self.config.run.seed))
122
+ # if self.config.run.augmentations.bandmask:
123
+ # augments.append(augmentations.BandMask(self.config.run.augmentations.bandmask, sample_rate=args.sample_rate, rng=self.rng, seed=self.config.run.seed))
124
+ # if self.config.run.augmentations.revecho:
125
+ # augments.append(
126
+ # augmentations.RevEcho(proba=self.config.run.augmentations.revecho,rng=self.rng,seed=self.config.run.seed))
127
+ # self.augment = torch.nn.Sequential(*augments)
128
+
129
+ self.log_config()
130
+
131
+ def unwrap_dist_model(self, model):
132
+ if self.use_distributed:
133
+ return model.module
134
+ else:
135
+ return model
136
+
137
+ def train_epoch(self, epoch):
138
+ self.model.train()
139
+
140
+ metric_logger = MetricLogger(delimiter=" ")
141
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
142
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
143
+
144
+ logging.info("Start training epoch {}, {} iters per inner epoch.".format(epoch, self.iters_per_epoch))
145
+ header = "Train: data epoch: [{}]".format(epoch)
146
+
147
+ # Get gradient clipping parameters from config
148
+ clip_grad_norm = self.config.run.optims.max_grad_norm
149
+ clip_grad_value = self.config.run.optims.max_grad_value
150
+
151
+ for i in metric_logger.log_every(
152
+ range(self.iters_per_epoch),
153
+ self.config.run.log_freq,
154
+ header=header,
155
+ logger=self.log_writter,
156
+ start_step=epoch * self.iters_per_epoch,
157
+ ):
158
+ if i >= self.iters_per_epoch:
159
+ break
160
+
161
+ samples = next(self.train_loader)
162
+
163
+ samples = prepare_sample_dist(samples, self.device)
164
+
165
+ #### augmentation
166
+ # if False:
167
+ # samples = self.augment(samples)
168
+
169
+ self.scheduler.step(cur_epoch=epoch, cur_step=i)
170
+
171
+ with torch.autocast(self.device.type, enabled=self.use_amp, dtype=torch.bfloat16):
172
+ loss = self.model(samples)["loss"]
173
+ if torch.isnan(loss):
174
+ print("loss nan", samples)
175
+ # continue
176
+
177
+ if self.use_amp and self.scaler:
178
+ self.scaler.scale(loss).backward()
179
+ else:
180
+ loss.backward()
181
+
182
+ # Apply gradient clipping
183
+ if clip_grad_norm is not None:
184
+ if self.use_amp and self.scaler:
185
+ self.scaler.unscale_(self.optimizer)
186
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)
187
+ if clip_grad_value is not None:
188
+ if self.use_amp and self.scaler:
189
+ self.scaler.unscale_(self.optimizer)
190
+ torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=clip_grad_value)
191
+
192
+ if (i + 1) % self.config.run.accum_grad_iters == 0:
193
+ if self.use_amp and self.scaler:
194
+ self.scaler.step(self.optimizer)
195
+ self.scaler.update()
196
+ else:
197
+ self.optimizer.step()
198
+ self.optimizer.zero_grad()
199
+
200
+ metric_logger.update(loss=loss.item())
201
+ metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
202
+
203
+ metric_logger.synchronize_between_processes()
204
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
205
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
206
+
207
+ @torch.no_grad()
208
+ def valid_epoch(self, epoch, split, decode=True, save_json=False, decode_ratio=1.0):
209
+ """
210
+ Decode = True will lead to calculation of custom metrics which are based on text.
211
+ decode_ratio controls the percentage of batches which will have custom metrics computed,
212
+ a speed trade-off due to the cost of the 'generate' method.
213
+ """
214
+ model = self.unwrap_dist_model(self.model)
215
+ model.eval()
216
+
217
+ dataloader = getattr(self, split + "_loader", None)
218
+ assert dataloader is not None, f"{split}_loader does not exist."
219
+
220
+ metric_logger = MetricLogger(delimiter=" ")
221
+ header = f"Eval: data epoch: [{epoch}]"
222
+
223
+ results_per_task = defaultdict(list) # Store results per task
224
+ overall_results = [] # Store all results for overall metrics
225
+
226
+ # Calculate N based on decode_ratio
227
+ if decode_ratio <= 0.0:
228
+ N = float("inf") # Effectively never run generate
229
+ elif decode_ratio >= 1.0:
230
+ N = 1 # Run generate every batch
231
+ else:
232
+ N = max(int(1 / decode_ratio), 1) # Ensure N is at least 1
233
+
234
+ batch_idx = 0
235
+
236
+ # Initialize overall metrics
237
+ overall_res = {
238
+ "loss": torch.tensor(0.0, device=self.device),
239
+ "correct": torch.tensor(0.0, device=self.device),
240
+ "total": torch.tensor(0.0, device=self.device),
241
+ }
242
+
243
+ # Initialize per-task metrics
244
+ per_task_res = defaultdict(
245
+ lambda: {
246
+ "loss": torch.tensor(0.0, device=self.device),
247
+ "correct": torch.tensor(0.0, device=self.device),
248
+ "total": torch.tensor(0.0, device=self.device),
249
+ "n_sample": 0,
250
+ "predicted_texts": [],
251
+ "gold_texts": [],
252
+ }
253
+ )
254
+
255
+ for samples in metric_logger.log_every(dataloader, self.config.run.log_freq, header=header):
256
+ samples = prepare_sample_dist(samples, self.device)
257
+
258
+ with torch.autocast(self.device.type, enabled=self.use_amp):
259
+ forward_result = model(samples, verbose=True)
260
+
261
+ # Extract batch-level loss and correct counts
262
+ batch_loss = forward_result.get("loss", torch.tensor(0.0, device=self.device))
263
+ batch_correct = forward_result.get("correct", torch.tensor(0.0, device=self.device))
264
+ batch_total = forward_result.get("total", torch.tensor(1.0, device=self.device))
265
+
266
+ batch_size = len(samples["id"])
267
+
268
+ # Update overall metrics with batch-level values
269
+ overall_res["loss"] += batch_loss.detach()
270
+ overall_res["correct"] += batch_correct.detach()
271
+ overall_res["total"] += batch_total.detach()
272
+
273
+ # Decide whether to run generate based on decode_ratio
274
+ if decode and (batch_idx % N == 0):
275
+ prompts = samples.get("prompt", None)
276
+ try:
277
+ generated_texts = model.generate(samples, self.config.generate, prompts=prompts)
278
+ except Exception as e:
279
+ print("error in generation", e)
280
+ generated_texts = [None] * batch_size
281
+ else:
282
+ generated_texts = [None] * batch_size # Placeholder if not decoding
283
+
284
+ # Process per-sample data for per-task metrics and result saving
285
+ for i in range(batch_size):
286
+ task = samples["task"][i]
287
+
288
+ # Collect per-task batch-level metrics
289
+ per_task_res[task]["loss"] += batch_loss.detach()
290
+ per_task_res[task]["correct"] += batch_correct.detach()
291
+ per_task_res[task]["total"] += batch_total.detach()
292
+ per_task_res[task]["n_sample"] += 1
293
+
294
+ res = {
295
+ "id": samples["id"][i],
296
+ "ground_truth": samples["text"][i], # Gold label from dataloader
297
+ "task": task,
298
+ "predicted_text": generated_texts[i],
299
+ }
300
+
301
+ if decode and generated_texts[i] is not None:
302
+ res["prompt"] = samples.get("prompt", [None])[i]
303
+
304
+ results_per_task[task].append(res)
305
+ overall_results.append(res)
306
+
307
+ # Collect texts for custom metrics
308
+ if generated_texts[i] is not None:
309
+ per_task_res[task]["predicted_texts"].append(generated_texts[i])
310
+ per_task_res[task]["gold_texts"].append(samples["text"][i])
311
+
312
+ batch_idx += 1 # Increment batch index
313
+
314
+ if save_json:
315
+ for task, task_results in results_per_task.items():
316
+ self.save_result(task_results, self.output_dir, f"eval_{split}_{task}_epoch_{epoch}")
317
+ # Optionally save overall results
318
+ self.save_result(overall_results, self.output_dir, f"eval_{split}_epoch_{epoch}")
319
+
320
+ # Synchronize metrics across processes if in distributed mode
321
+ if is_dist_avail_and_initialized():
322
+ for key in overall_res:
323
+ dist.all_reduce(overall_res[key])
324
+
325
+ overall_ret = {
326
+ "loss": (overall_res["loss"] / batch_idx).item(),
327
+ "agg_metrics": (overall_res["correct"] / overall_res["total"]).item(),
328
+ }
329
+
330
+ if is_main_process():
331
+ # Log overall metrics
332
+ wandb.log(
333
+ {
334
+ f"{split}_loss": overall_ret["loss"],
335
+ f"{split}_accuracy": overall_ret["agg_metrics"],
336
+ "epoch": epoch,
337
+ }
338
+ )
339
+
340
+ # Compute and log per-task metrics
341
+ for task, res in per_task_res.items():
342
+ if "caption-none" in task:
343
+ continue
344
+
345
+ if self.use_distributed:
346
+ print(f"Rank {dist.get_rank()}, task={task}, ")
347
+
348
+ print(
349
+ f"loss={res['loss'].shape, res['loss'].dtype}, "
350
+ f"correct={res['correct'].shape, res['correct'].dtype}, "
351
+ f"total={res['total'].shape, res['total'].dtype}, "
352
+ f"n_sample={res['n_sample']}"
353
+ )
354
+
355
+ # Synchronize metrics across processes if in distributed mode
356
+ if is_dist_avail_and_initialized():
357
+ dist.all_reduce(res["loss"])
358
+ dist.all_reduce(res["correct"])
359
+ dist.all_reduce(res["total"])
360
+ dist.all_reduce(torch.tensor(res["n_sample"], device=self.device))
361
+
362
+ ret = {
363
+ "loss": (res["loss"] / res["n_sample"]).item(),
364
+ "agg_metrics": (res["correct"] / res["total"]).item(),
365
+ }
366
+
367
+ if is_main_process():
368
+ # Log per-task metrics
369
+ wandb.log(
370
+ {
371
+ f"{split}_{task}_loss": ret["loss"],
372
+ f"{split}_{task}_accuracy": ret["agg_metrics"],
373
+ "epoch": epoch,
374
+ }
375
+ )
376
+
377
+ # Get and compute custom metrics for this task
378
+ metrics_list = get_task_metrics(task)
379
+ predicted_texts = res["predicted_texts"]
380
+ gold_texts = res["gold_texts"]
381
+ for metric in metrics_list:
382
+ if predicted_texts and gold_texts:
383
+ metric_value = metric.compute_metric(predicted_texts, gold_texts)
384
+ metric_name = metric.__class__.__name__
385
+ wandb.log(
386
+ {
387
+ f"{split}_{task}_{metric_name}": metric_value,
388
+ "epoch": epoch,
389
+ }
390
+ )
391
+ return overall_ret # Return overall metrics
392
+
393
+ def save_result(self, result, result_dir, filename):
394
+ result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, get_rank()))
395
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
396
+
397
+ try:
398
+ json.dump(result, open(result_file, "w"), ensure_ascii=False)
399
+ except Exception as e:
400
+ logging.warning(f"Error saving {result_file}. Error: {e}")
401
+ json.dump(result, open(result_file, "w", encoding="utf-8"), ensure_ascii=False)
402
+
403
+ # if is_dist_avail_and_initialized():
404
+ # dist.barrier()
405
+
406
+ if is_main_process():
407
+ logging.info("rank %d starts merging results." % get_rank())
408
+ result = []
409
+
410
+ for rank in range(get_world_size()):
411
+ result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
412
+ try:
413
+ res = json.load(open(result_file, "r"))
414
+ except Exception as e:
415
+ logging.warning(f"Error reading {result_file}. Error: {e}")
416
+ res = json.load(open(result_file, "r", encoding="utf-8"))
417
+ result += res
418
+
419
+ try:
420
+ json.dump(result, open(final_result_file, "w"), ensure_ascii=False)
421
+ except Exception as e:
422
+ logging.warning(f"Error saving {final_result_file}. Error: {e}")
423
+ json.dump(
424
+ result,
425
+ open(final_result_file, "w", encoding="utf-8"),
426
+ ensure_ascii=False,
427
+ )
428
+
429
+ print("result file saved to %s" % final_result_file)
430
+
431
+ def train(self):
432
+ start_time = time.time()
433
+ best_agg_metric = 0
434
+ best_epoch = 0
435
+
436
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
437
+ if self.evaluate_only:
438
+ break
439
+
440
+ # training phase
441
+ logging.info("Training Phase")
442
+ train_stats = self.train_epoch(cur_epoch)
443
+ self.log_stats(train_stats, split_name="train")
444
+
445
+ # validating phase
446
+ logging.info("Validating Phase")
447
+ valid_log = self.valid_epoch(
448
+ cur_epoch,
449
+ "valid",
450
+ decode=self.config.run.custom_metrics,
451
+ save_json=False,
452
+ decode_ratio=self.config.run.decode_ratio,
453
+ )
454
+ if valid_log is not None:
455
+ if is_main_process():
456
+ agg_metrics = valid_log["agg_metrics"]
457
+ if agg_metrics > best_agg_metric:
458
+ best_agg_metric = agg_metrics
459
+ best_epoch = cur_epoch
460
+ self.save_checkpoint(cur_epoch, is_best=True)
461
+
462
+ valid_log.update({"best_epoch": best_epoch})
463
+ self.log_stats(valid_log, split_name="valid")
464
+ self.save_checkpoint(cur_epoch, is_best=False)
465
+
466
+ # if self.use_distributed:
467
+ # dist.barrier()
468
+
469
+ # testing phase
470
+ if self.evaluate_only:
471
+ self.valid_epoch("best", "test", decode=True, save_json=True)
472
+
473
+ total_time = time.time() - start_time
474
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
475
+ logging.info("Training time {}".format(total_time_str))
476
+
477
+ @main_process
478
+ def log_config(self):
479
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
480
+ f.write(json.dumps(self.config.model_dump_json(), indent=4) + "\n")
481
+
482
+ @main_process
483
+ def log_stats(self, stats, split_name):
484
+ if isinstance(stats, dict):
485
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
486
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
487
+ f.write(json.dumps(log_stats) + "\n")
488
+ elif isinstance(stats, list):
489
+ pass
490
+
491
+ @main_process
492
+ def save_checkpoint(self, cur_epoch, is_best=False):
493
+ """
494
+ Save the checkpoint at the current epoch.
495
+ """
496
+ model_no_ddp = self.unwrap_dist_model(self.model)
497
+ param_grad_dic = {k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()}
498
+ state_dict = model_no_ddp.state_dict()
499
+ for k in list(state_dict.keys()):
500
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
501
+ # delete parameters that do not require gradient
502
+ del state_dict[k]
503
+ save_obj = {
504
+ "model": state_dict,
505
+ "optimizer": self.optimizer.state_dict(),
506
+ "config": dict(self.config),
507
+ "scaler": self.scaler.state_dict() if self.scaler else None,
508
+ "epoch": cur_epoch,
509
+ }
510
+ save_to = os.path.join(
511
+ self.output_dir,
512
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
513
+ )
514
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
515
+ torch.save(save_obj, save_to)
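For orientation, a minimal sketch of how this Runner is typically driven from a training entry point. The config loader and dataset builders below are illustrative placeholders, not the repository's actual train script; only Runner, Config, and now_as_str come from the code in this commit.

    # Hypothetical driver sketch (load_config/build_model/build_datasets are placeholders).
    from NatureLM.config import Config
    from NatureLM.utils import now_as_str

    def main():
        cfg: Config = load_config("configs/train.yml")   # placeholder: build a Config from YAML
        model = build_model(cfg.model)                   # placeholder: construct the NatureLM model
        datasets = build_datasets(cfg)                   # placeholder: {"train": ..., "valid": ..., "test": ...}
        runner = Runner(cfg, model, datasets, job_id=now_as_str())
        runner.train()  # per epoch: train_epoch -> valid_epoch -> checkpointing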
NatureLM/storage_utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import logging
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Union
5
+
6
+ import cloudpathlib
7
+ from google.cloud.storage.client import Client
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def is_gcs_path(path: Union[str, os.PathLike]) -> bool:
13
+ return str(path).startswith("gs://")
14
+
15
+
16
+ @lru_cache(maxsize=1)
17
+ def _get_client():
18
+ return cloudpathlib.GSClient(storage_client=Client())
19
+
20
+
21
+ try:
22
+ _gcp_storage_client = _get_client()
23
+ except Exception:
24
+ logger.warning("Failed to initialize GCS client. Training won't be able to use GSPath or R2Path without a client.")
25
+ _gcp_storage_client = None
26
+
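A small usage sketch for the helpers above; the checkpoint URI is a placeholder, and passing an explicit client to GSPath follows cloudpathlib's documented API.

    # Illustrative: open a path through the shared GCS client when it is a gs:// URI.
    import cloudpathlib

    def open_binary(path: str):
        if is_gcs_path(path) and _gcp_storage_client is not None:
            # Reuse the lazily created, cached client instead of building a new one per call.
            return cloudpathlib.GSPath(path, client=_gcp_storage_client).open("rb")
        return open(path, "rb")

    # e.g. open_binary("gs://my-bucket/checkpoint_best.pth")  # placeholder bucket name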
NatureLM/task_metric_utils.py ADDED
@@ -0,0 +1,283 @@
1
+ # Taken from DCASE 2021 Task 5 evaluation source code
2
+ # https://github.com/c4dm/dcase-few-shot-bioacoustic
3
+ # MIT License
4
+
5
+ import mir_eval
6
+ import numpy as np
7
+ import scipy
8
+
9
+
10
+ def fast_intersect(ref, est):
11
+ """Find all intersections between reference events and estimated events (fast).
12
+ Best-case complexity: O(N log N + M log M) where N=length(ref) and M=length(est)
13
+ Parameters
14
+ ----------
15
+ ref: np.ndarray [shape=(2, n)], real-valued
16
+ Array of reference events. Each column is an event.
17
+ The first row denotes onset times and the second row denotes offset times.
18
+ est: np.ndarray [shape=(2, m)], real-valued
19
+ Array of estimated events. Each column is an event.
20
+ The first row denotes onset times and the second row denotes offset times.
21
+ Returns
22
+ -------
23
+ matches: list of sets, length n, integer-valued
24
+ Property: matches[i] contains the set of all indices j such that
25
+ (ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
26
+ """
27
+ ref_on_argsort = np.argsort(ref[0, :])
28
+ ref_off_argsort = np.argsort(ref[1, :])
29
+
30
+ est_on_argsort = np.argsort(est[0, :])
31
+ est_off_argsort = np.argsort(est[1, :])
32
+
33
+ est_on_maxindex = est.shape[1]
34
+ est_off_minindex = 0
35
+ estref_matches = [set()] * ref.shape[1]
36
+ refest_matches = [set()] * ref.shape[1]
37
+ for ref_id in range(ref.shape[1]):
38
+ ref_onset = ref[0, ref_on_argsort[ref_id]]
39
+ est_off_sorted = est[1, est_off_argsort[est_off_minindex:]]
40
+ search_result = np.searchsorted(est_off_sorted, ref_onset, side="left")
41
+ est_off_minindex += search_result
42
+ refest_match = est_off_argsort[est_off_minindex:]
43
+ refest_matches[ref_on_argsort[ref_id]] = set(refest_match)
44
+
45
+ ref_offset = ref[1, ref_off_argsort[-1 - ref_id]]
46
+ est_on_sorted = est[0, est_on_argsort[: (1 + est_on_maxindex)]]
47
+ search_result = np.searchsorted(est_on_sorted, ref_offset, side="right")
48
+ est_on_maxindex = search_result - 1
49
+ estref_match = est_on_argsort[: (1 + est_on_maxindex)]
50
+ estref_matches[ref_off_argsort[-1 - ref_id]] = set(estref_match)
51
+
52
+ zip_iterator = zip(refest_matches, estref_matches)
53
+ matches = [x.intersection(y) for (x, y) in zip_iterator]
54
+ return matches
55
+
56
+
57
+ def iou(ref, est, method="fast"):
58
+ """Compute pairwise "intersection over union" (IOU) metric between reference
59
+ events and estimated events.
60
+ Let us denote by a_i and b_i the onset and offset of reference event i.
61
+ Let us denote by u_j and v_j the onset and offset of estimated event j.
62
+ The IOU between events i and j is defined as
63
+ (min(b_i, v_j)-max(a_i, u_j)) / (max(b_i, v_j)-min(a_i, u_j))
64
+ if the events are non-disjoint, and equal to zero otherwise.
65
+ Parameters
66
+ ----------
67
+ ref: np.ndarray [shape=(2, n)], real-valued
68
+ Array of reference events. Each column is an event.
69
+ The first row denotes onset times and the second row denotes offset times.
70
+ est: np.ndarray [shape=(2, m)], real-valued
71
+ Array of estimated events. Each column is an event.
72
+ The first row denotes onset times and the second row denotes offset times.
73
+ method: str, optional.
74
+ If "fast" (default), computes pairwise intersections via a custom
75
+ dynamic programming algorithm, see fast_intersect.
76
+ If "slow", computes pairwise intersections via bruteforce quadratic
77
+ search, see slow_intersect.
78
+ Returns
79
+ -------
80
+ S: scipy.sparse.dok.dok_matrix, real-valued
81
+ Sparse 2-D matrix. S[i,j] contains the IOU between ref[i] and est[j]
82
+ if these events are non-disjoint and zero otherwise.
83
+ """
84
+ n_refs = ref.shape[1]
85
+ n_ests = est.shape[1]
86
+ S = scipy.sparse.dok_matrix((n_refs, n_ests))
87
+
88
+ if method == "fast":
89
+ matches = fast_intersect(ref, est)
90
+ elif method == "slow":
91
+ matches = slow_intersect(ref, est)
92
+
93
+ for ref_id in range(n_refs):
94
+ matching_ests = matches[ref_id]
95
+ ref_on = ref[0, ref_id]
96
+ ref_off = ref[1, ref_id]
97
+
98
+ for matching_est_id in matching_ests:
99
+ est_on = est[0, matching_est_id]
100
+ est_off = est[1, matching_est_id]
101
+ intersection = min(ref_off, est_off) - max(ref_on, est_on)
102
+ union = max(ref_off, est_off) - min(ref_on, est_on)
103
+ intersection_over_union = intersection / union
104
+ S[ref_id, matching_est_id] = intersection_over_union
105
+
106
+ return S
107
+
108
+ def compute_intersection(ref, est, method="fast"):
109
+ """Compute pairwise intersection between reference
110
+ events and estimated events.
111
+ Let us denote by a_i and b_i the onset and offset of reference event i.
112
+ Let us denote by u_j and v_j the onset and offset of estimated event j.
113
+ The Intersection between events i and j is defined as
114
+ (min(b_i, v_j)-max(a_i, u_j))
115
+ if the events are non-disjoint, and equal to zero otherwise.
116
+ Parameters
117
+ ----------
118
+ ref: np.ndarray [shape=(2, n)], real-valued
119
+ Array of reference events. Each column is an event.
120
+ The first row denotes onset times and the second row denotes offset times.
121
+ est: np.ndarray [shape=(2, m)], real-valued
122
+ Array of estimated events. Each column is an event.
123
+ The first row denotes onset times and the second row denotes offset times.
124
+ method: str, optional.
125
+ If "fast" (default), computes pairwise intersections via a custom
126
+ dynamic programming algorithm, see fast_intersect.
127
+ If "slow", computes pairwise intersections via bruteforce quadratic
128
+ search, see slow_intersect.
129
+ Returns
130
+ -------
131
+ S: scipy.sparse.dok.dok_matrix, real-valued
132
+ Sparse 2-D matrix. S[i,j] contains the Intersection between ref[i] and est[j]
133
+ if these events are non-disjoint and zero otherwise.
134
+ """
135
+ n_refs = ref.shape[1]
136
+ n_ests = est.shape[1]
137
+ S = scipy.sparse.dok_matrix((n_refs, n_ests))
138
+
139
+ if method == "fast":
140
+ matches = fast_intersect(ref, est)
141
+ elif method == "slow":
142
+ matches = slow_intersect(ref, est)
143
+
144
+ for ref_id in range(n_refs):
145
+ matching_ests = matches[ref_id]
146
+ ref_on = ref[0, ref_id]
147
+ ref_off = ref[1, ref_id]
148
+
149
+ for matching_est_id in matching_ests:
150
+ est_on = est[0, matching_est_id]
151
+ est_off = est[1, matching_est_id]
152
+ intersection = min(ref_off, est_off) - max(ref_on, est_on)
153
+ # union = max(ref_off, est_off) - min(ref_on, est_on)
154
+ # intersection_over_union = intersection / union
155
+ S[ref_id, matching_est_id] = intersection #_over_union
156
+
157
+ return S
158
+
159
+
160
+ def match_events(ref, est, min_iou=0.0, method="fast"):
161
+ """
162
+ Compute a maximum matching between reference and estimated event times,
163
+ subject to a criterion of minimum intersection-over-union (IOU).
164
+ Given two lists of events ``ref`` (reference) and ``est`` (estimated),
165
+ we seek the largest set of correspondences ``(ref[i], est[j])`` such that
166
+ ``iou(ref[i], est[j]) <= min_iou``
167
+ and such that each ``ref[i]`` and ``est[j]`` is matched at most once.
168
+ This function is strongly inspired by mir_eval.onset.util.match_events.
169
+ It relies on mir_eval's implementation of the Hopcroft-Karp algorithm from
170
+ maximum bipartite graph matching. However, one important difference is that
171
+ mir_eval's distance function relies purely on onset times, whereas this function
172
+ considers both onset times and offset times to compute the IOU metric between
173
+ reference events and estimated events.
174
+ Parameters
175
+ ----------
176
+ ref: np.ndarray [shape=(2, n)], real-valued
177
+ Array of reference events. Each column is an event.
178
+ The first row denotes onset times and the second row denotes offset times.
179
+ est: np.ndarray [shape=(2, m)], real-valued
180
+ Array of estimated events. Each column is an event.
181
+ The first row denotes onset times and the second row denotes offset times.
182
+ min_iou: real number in [0, 1). Default: 0.
183
+ Threshold for minimum amount of intersection over union (IOU) to match
184
+ any two events. See the iou method for implementation details.
185
+ method: str, optional.
186
+ If "fast" (default), computes pairwise intersections via a custom
187
+ dynamic programming algorithm, see fast_intersect.
188
+ If "slow", computes pairwise intersections via bruteforce quadratic
189
+ search, see slow_intersect.
190
+ Returns
191
+ -------
192
+ matching : list of tuples
193
+ Every tuple corresponds to a match between one reference event and
194
+ one estimated event.
195
+ ``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
196
+ Note that all values i and j appear at most once in the list.
197
+ """
198
+
199
+ # Intersect reference events and estimated events
200
+ S = iou(ref, est, method=method)
201
+
202
+ # Threshold intersection-over-union (IOU) ratio
203
+ S_bool = scipy.sparse.dok_matrix(S > min_iou)
204
+ hits = S_bool.keys()
205
+
206
+ # Construct the bipartite graph
207
+ G = {}
208
+ for ref_i, est_i in hits:
209
+ if est_i not in G:
210
+ G[est_i] = []
211
+ G[est_i].append(ref_i)
212
+
213
+ # Apply Hopcroft-Karp algorithm (from mir_eval package)
214
+ # to obtain maximum bipartite graph matching
215
+ matching = sorted(mir_eval.util._bipartite_match(G).items())
216
+ return matching
217
+
218
+
219
+ def slow_intersect(ref, est):
220
+ """Find all intersections between reference events and estimated events (slow).
221
+ Best-case complexity: O(N*M) where N=ref.shape[1] and M=est.shape[1]
222
+ Parameters
223
+ ----------
224
+ ref: np.ndarray [shape=(2, n)], real-valued
225
+ Array of reference events. Each column is an event.
226
+ The first row denotes onset times and the second row denotes offset times.
227
+ est: np.ndarray [shape=(2, m)], real-valued
228
+ Array of estimated events. Each column is an event.
229
+ The first row denotes onset times and the second row denotes offset times.
230
+ Returns
231
+ -------
232
+ matches: list of sets, length n, integer-valued
233
+ Property: matches[i] contains the set of all indices j such that
234
+ (ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
235
+ """
236
+ matches = []
237
+ for i in range(ref.shape[1]):
238
+ matches.append(
239
+ set(
240
+ [
241
+ j
242
+ for j in range(est.shape[1])
243
+ if ((ref[0, i] <= est[1, j]) and (ref[1, i] >= est[0, j]))
244
+ ]
245
+ )
246
+ )
247
+ return
248
+
249
+
250
+ def frames_to_st_dict(x, sr=16000):
251
+ # x : Tensor of shape (batch, time) or (time,). Entries are 2 (POS), 1 (UNK), and 0 (NEG).
252
+ # returns a list of dicts {"Begin Time (s)" : [...], "End Time (s)" : [...], "Annotation" : [...]} if batch dim exists, or a single dict
253
+
254
+ if len(x.size()) == 2:
255
+ outs = []
256
+ for i in range(x.size(0)):
257
+ x_sub = x[i,:]
258
+ outs.append(_frames_to_st_dict_single(x_sub, sr=sr))
259
+ return outs
260
+ else:
261
+ return _frames_to_st_dict_single(x, sr=sr)
262
+
263
+ def _frames_to_st_dict_single(x, sr=16000):
264
+ d = {"Begin Time (s)" : [], "End Time (s)" : [], "Annotation" : []}
265
+
266
+ for label_i in [1,2]:
267
+
268
+ labels = x.numpy() == label_i # POS : 2, UNK : 1, NEG : 0
269
+
270
+ starts = np.where((~labels[:-1]) & (labels[1:]))[0] + 1
271
+ if labels[0]:
272
+ starts = np.insert(starts, 0, 0)
273
+
274
+ ends = np.where((labels[:-1]) & (~labels[1:]))[0] + 1
275
+ if labels[-1]:
276
+ ends = np.append(ends, len(labels))
277
+
278
+ for start, end in zip(starts, ends):
279
+ d["Begin Time (s)"].append(start/sr)
280
+ d["End Time (s)"].append(end/sr)
281
+ d["Annotation"].append("POS" if label_i == 2 else "UNK")
282
+
283
+ return d
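A tiny worked example of the event-matching utilities above, mainly to make the (2, n) onset/offset layout concrete; the numbers are arbitrary.

    # Columns are events; row 0 holds onsets, row 1 holds offsets (seconds).
    import numpy as np

    ref = np.array([[0.0, 5.0],
                    [1.0, 6.0]])   # reference events: (0-1 s) and (5-6 s)
    est = np.array([[0.1, 7.0],
                    [0.9, 8.0]])   # estimated events: (0.1-0.9 s) and (7-8 s)

    print(iou(ref, est).toarray())              # IOU(ref0, est0) = 0.8, everything else 0
    print(match_events(ref, est, min_iou=0.5))  # one matched pair: ref0 <-> est0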
NatureLM/task_metrics.py ADDED
@@ -0,0 +1,128 @@
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from typing import List, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from NatureLM.task_metric_utils import match_events
8
+
9
+ # Assume the following functions are imported from the reference implementations:
10
+ # - match_events
11
+ # - iou
12
+ # - fast_intersect
13
+ # - slow_intersect
14
+ # - compute_intersection
15
+
16
+
17
+ class Metric(ABC):
18
+ @abstractmethod
19
+ def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
20
+ pass
21
+
22
+
23
+ class ExactAccuracy(Metric):
24
+ """Exact-match accuracy metric."""
25
+
26
+ def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
27
+ predicted_texts = [pt.lower().strip() for pt in predicted_texts]
28
+ gold_texts = [gt.lower().strip() for gt in gold_texts]
29
+ correct = sum(p == g for p, g in zip(predicted_texts, gold_texts))
30
+ return correct / len(gold_texts) if gold_texts else 0.0
31
+
32
+
33
+ class FewShot(Metric):
34
+ """Few-shot learning metric based on event matching using IoU."""
35
+
36
+ def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
37
+ # Initialize counts
38
+ total_TP = 0
39
+ total_FP = 0
40
+ total_FN = 0
41
+
42
+ for pred_text, gold_text in zip(predicted_texts, gold_texts):
43
+ # Extract events from texts
44
+ pred_events = parse_timestamps_from_text(pred_text)
45
+ gold_events = parse_timestamps_from_text(gold_text)
46
+
47
+ # Convert events to numpy arrays for match_events function
48
+ # Each event is (start_time, end_time), need to transpose to shape (2, n)
49
+ pred_array = np.array(pred_events).T if pred_events else np.empty((2, 0))
50
+ gold_array = np.array(gold_events).T if gold_events else np.empty((2, 0))
51
+
52
+ # Use match_events function from the reference implementation
53
+ matches = match_events(gold_array, pred_array, min_iou=0.5, method="fast")
54
+
55
+ TP = len(matches)
56
+ FP = len(pred_events) - TP
57
+ FN = len(gold_events) - TP
58
+
59
+ total_TP += TP
60
+ total_FP += FP
61
+ total_FN += FN
62
+
63
+ # Compute precision, recall, and F1 score
64
+ precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
65
+ recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
66
+ f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
67
+
68
+ return f1_score
69
+
70
+
71
+ class NoneAccuracy(Metric):
72
+ """Accuracy for cases where 'None' is the correct answer."""
73
+
74
+ def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
75
+ # Normalize texts
76
+ predicted_texts = [pt.lower().strip() for pt in predicted_texts]
77
+ gold_texts = [gt.lower().strip() for gt in gold_texts]
78
+ # Filter indices where gold_text is 'none'
79
+ indices = [i for i, gt in enumerate(gold_texts) if gt == "none"]
80
+ if not indices:
81
+ return 0.0 # No 'None' cases in gold_texts
82
+ correct = sum(predicted_texts[i] == "none" for i in indices)
83
+ return correct / len(indices)
84
+
85
+
86
+ class MultipleSpeciesAccuracy(Metric):
87
+ """Accuracy for cases where the correct answer has at least one comma (multiple species)."""
88
+
89
+ def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
90
+ # Normalize texts
91
+ predicted_texts = [pt.lower().strip() for pt in predicted_texts]
92
+ gold_texts = [gt.lower().strip() for gt in gold_texts]
93
+ # Filter indices where gold_text contains at least one comma
94
+ indices = [i for i, gt in enumerate(gold_texts) if "," in gt]
95
+ if not indices:
96
+ return 0.0 # No multiple-species cases in gold_texts
97
+ correct = sum(predicted_texts[i] == gold_texts[i] for i in indices)
98
+ return correct / len(indices)
99
+
100
+
101
+ def get_task_metrics(task: str) -> List[Metric]:
102
+ """Get a list of metric instances appropriate for the given task."""
103
+ all_metrics = []
104
+ metrics_dict = {}
105
+
106
+ if "classification" in task:
107
+ metrics_dict["ExactAccuracy"] = ExactAccuracy()
108
+ if "fewshot" in task:
109
+ metrics_dict["FewShot"] = FewShot()
110
+ if "detection" in task:
111
+ metrics_dict["ExactAccuracy"] = ExactAccuracy() # Ensures no duplicate
112
+ metrics_dict["NoneAccuracy"] = NoneAccuracy()
113
+ metrics_dict["MultipleSpeciesAccuracy"] = MultipleSpeciesAccuracy()
114
+
115
+ all_metrics = list(metrics_dict.values())
116
+ return all_metrics
117
+
118
+
119
+ def parse_timestamps_from_text(text: str) -> List[Tuple[float, float]]:
120
+ """
121
+ Function to parse timestamps from text.
122
+ Extracts timestamps in the format "start-end" where start and end are floats.
123
+ """
124
+ # Regular expression to extract timestamps in the format "start-end"
125
+ pattern = r"(\d+\.\d+)-(\d+\.\d+)"
126
+ matches = re.findall(pattern, text)
127
+ events = [(float(start), float(end)) for start, end in matches]
128
+ return events
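A brief illustration of how these metrics get selected and applied; the task string and prediction strings are invented, and any task name containing both "detection" and "fewshot" would pick the same metric set.

    # Made-up predictions/golds in the "start-end" timestamp format parsed above.
    preds = ["0.50-1.20, 3.00-3.80", "none"]
    golds = ["0.55-1.25, 3.10-3.90", "none"]

    for metric in get_task_metrics("fewshot-detection"):
        name = metric.__class__.__name__
        print(name, metric.compute_metric(preds, golds))
    # FewShot matches parsed events at IOU above 0.5 and reports F1;
    # NoneAccuracy only scores the samples whose gold answer is "none".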
NatureLM/utils.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright (2024) Earth Species Project
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ import time
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+ from typing import Any, Literal
21
+
22
+ import numpy as np
23
+ import resampy
24
+ import soundfile as sf
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torchaudio
28
+ from torch.utils.data import DataLoader, DistributedSampler
29
+
30
+ from NatureLM.dist_utils import get_rank, get_world_size
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ TARGET_SAMPLE_RATE = 16_000
36
+
37
+
38
+ def snr_scale(clean, noise, snr):
39
+ # Ensure both clean and noise have the same length
40
+ assert clean.shape == noise.shape, "Clean and noise must have the same shape."
41
+
42
+ # Compute power (mean squared amplitude)
43
+ power_signal = torch.mean(clean**2)
44
+ power_noise = torch.mean(noise**2)
45
+
46
+ # Prevent division by zero
47
+ epsilon = 1e-10
48
+ power_noise = torch.clamp(power_noise, min=epsilon)
49
+
50
+ # Calculate desired noise power based on SNR
51
+ desired_noise_power = power_signal / (10 ** (snr / 10))
52
+
53
+ # Scale noise to achieve the desired noise power
54
+ scale = torch.sqrt(desired_noise_power / power_noise)
55
+ scaled_noise = scale * noise
56
+
57
+ return scaled_noise
58
+
59
+
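As a quick sanity check of the scaling rule above: the target noise power is power_signal / 10^(snr/10), so the amplitude scale is the square root of that target divided by the current noise power.

    # e.g. for snr = 10 dB the scaled noise carries one tenth of the signal's power.
    clean = torch.randn(16000)
    noise = torch.randn(16000)
    mix = clean + snr_scale(clean, noise, snr=10.0)  # noise now sits ~10 dB below the clean signal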
60
+ def time_scale(signal, scale=2.0, rngnp=None, seed=42):
61
+ if rngnp is None:
62
+ rngnp = np.random.default_rng(seed=seed)
63
+ scaling = np.power(scale, rngnp.uniform(-1, 1))
64
+ output_size = int(signal.shape[-1] * scaling)
65
+ ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
66
+ ref1 = ref.clone().type(torch.int64)
67
+ ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64))
68
+ r = ref - ref1.type(ref.type())
69
+ scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
70
+
71
+ ## trim or zero pad to torche original size
72
+ if scaled_signal.shape[-1] > signal.shape[-1]:
73
+ nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
74
+ scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
75
+ else:
76
+ nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
77
+ pad_left = int(np.random.uniform() * nframes_diff)
78
+ pad_right = nframes_diff - pad_left
79
+ scaled_signal = F.pad(input=scaled_signal, pad=(pad_left, pad_right), mode="constant", value=0)
80
+ return scaled_signal
81
+
82
+
83
+ def mel_frequencies(n_mels, fmin, fmax):
84
+ def _hz_to_mel(f):
85
+ return 2595 * np.log10(1 + f / 700)
86
+
87
+ def _mel_to_hz(m):
88
+ return 700 * (10 ** (m / 2595) - 1)
89
+
90
+ low = _hz_to_mel(fmin)
91
+ high = _hz_to_mel(fmax)
92
+
93
+ mels = np.linspace(low, high, n_mels)
94
+
95
+ return _mel_to_hz(mels)
96
+
97
+
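For reference, the helper above uses the standard 2595 * log10(1 + f/700) mel mapping; a quick check of the endpoints and midpoint:

    # mel_frequencies(3, 0, 8000) places 3 points evenly in mel space and maps them back to Hz.
    # hz_to_mel(8000) is roughly 2840 mel, so the midpoint (1420 mel) maps back to about 1768 Hz,
    # i.e. the returned frequencies are denser at the low end, as expected for a mel scale.
    print(mel_frequencies(n_mels=3, fmin=0, fmax=8000))  # approx [0.0, 1767.7, 8000.0]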
98
+ def now_as_str() -> str:
99
+ return datetime.now().strftime("%Y%m%d%H%M")
100
+
101
+
102
+ def get_dataloader(dataset, config, is_train=True, use_distributed=True):
103
+ if use_distributed:
104
+ sampler = DistributedSampler(dataset, shuffle=is_train, num_replicas=get_world_size(), rank=get_rank())
105
+ else:
106
+ sampler = None
107
+
108
+ loader = DataLoader(
109
+ dataset,
110
+ batch_size=config.batch_size_train if is_train else config.batch_size_eval,
111
+ num_workers=config.num_workers,
112
+ pin_memory=False,
113
+ sampler=sampler,
114
+ shuffle=sampler is None and is_train,
115
+ collate_fn=dataset.collater,
116
+ drop_last=is_train,
117
+ )
118
+
119
+ if is_train:
120
+ loader = IterLoader(loader, use_distributed=use_distributed)
121
+
122
+ return loader
123
+
124
+
125
+ def apply_to_sample(f, sample):
126
+ if len(sample) == 0:
127
+ return {}
128
+
129
+ def _apply(x):
130
+ if torch.is_tensor(x):
131
+ return f(x)
132
+ elif isinstance(x, dict):
133
+ return {key: _apply(value) for key, value in x.items()}
134
+ elif isinstance(x, list):
135
+ return [_apply(x) for x in x]
136
+ else:
137
+ return x
138
+
139
+ return _apply(sample)
140
+
141
+
142
+ def move_to_device(sample, device):
143
+ def _move_to_device(tensor):
144
+ return tensor.to(device)
145
+
146
+ return apply_to_sample(_move_to_device, sample)
147
+
148
+
149
+ def prepare_sample(samples, cuda_enabled=True):
150
+ if cuda_enabled:
151
+ samples = move_to_device(samples, "cuda")
152
+
153
+ # TODO fp16 support
154
+
155
+ return samples
156
+
157
+
158
+ def prepare_sample_dist(samples, device):
159
+ samples = move_to_device(samples, device)
160
+
161
+ # TODO fp16 support
162
+
163
+ return samples
164
+
165
+
166
+ class IterLoader:
167
+ """
168
+ A wrapper to convert DataLoader as an infinite iterator.
169
+
170
+ Modified from:
171
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
172
+ """
173
+
174
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
175
+ self._dataloader = dataloader
176
+ self.iter_loader = iter(self._dataloader)
177
+ self._use_distributed = use_distributed
178
+ self._epoch = 0
179
+
180
+ @property
181
+ def epoch(self) -> int:
182
+ return self._epoch
183
+
184
+ def __next__(self):
185
+ try:
186
+ data = next(self.iter_loader)
187
+ except StopIteration:
188
+ self._epoch += 1
189
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
190
+ self._dataloader.sampler.set_epoch(self._epoch)
191
+ time.sleep(2) # Prevent possible deadlock during epoch transition
192
+ self.iter_loader = iter(self._dataloader)
193
+ data = next(self.iter_loader)
194
+
195
+ return data
196
+
197
+ def __iter__(self):
198
+ return self
199
+
200
+ def __len__(self):
201
+ return len(self._dataloader)
202
+
203
+
204
+ def prepare_one_sample(wav_path: str, wav_processor=None, cuda_enabled=True) -> dict:
205
+ """Prepare a single sample for inference.
206
+
207
+ Args:
208
+ wav_path: Path to the audio file.
209
+ wav_processor: A function to process the audio file.
210
+ cuda_enabled: Whether to move the sample to the GPU.
211
+ """
212
+ audio, sr = sf.read(wav_path)
213
+ if len(audio.shape) == 2: # stereo to mono
214
+ audio = audio.mean(axis=1)
215
+ if len(audio) < sr: # pad audio to at least 1s
216
+ sil = np.zeros(sr - len(audio), dtype=float)
217
+ audio = np.concatenate((audio, sil), axis=0)
218
+ audio = audio[: sr * 10] # truncate audio to at most 10s
219
+
220
+ # spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
221
+ print("audio shape", audio.shape)
222
+
223
+ audio_t = torch.tensor(audio).unsqueeze(0)
224
+ audio_t = torchaudio.functional.resample(audio_t, sr, TARGET_SAMPLE_RATE)
225
+ print("audio shape after resample", audio_t.shape)
226
+
227
+ samples = {
228
+ "raw_wav": audio_t,
229
+ "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
230
+ "audio_chunk_sizes": [1],
231
+ }
232
+ if cuda_enabled:
233
+ samples = move_to_device(samples, "cuda")
234
+
235
+ return samples
236
+
237
+
238
+ def prepare_one_sample_waveform(audio, cuda_enabled=True, sr=16000):
239
+ print("shape", audio.shape)
240
+ if len(audio.shape) == 2: # stereo to mono
241
+ print("converting stereo to mono?")
242
+ audio = audio.mean(axis=1)
243
+ if len(audio) < sr: # pad audio to at least 1s
244
+ sil = np.zeros(sr - len(audio), dtype=float)
245
+ audio = np.concatenate((audio, sil), axis=0)
246
+ audio = audio[: sr * 10] # truncate audio to at most 10s
247
+
248
+ samples = {
249
+ "raw_wav": torch.tensor(audio).unsqueeze(0).type(torch.DoubleTensor),
250
+ "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
251
+ }
252
+ if cuda_enabled:
253
+ samples = move_to_device(samples, "cuda")
254
+
255
+ return samples
256
+
257
+
258
+ def prepare_sample_waveforms(audio_paths, cuda_enabled=True, sr=TARGET_SAMPLE_RATE, max_length_seconds=10):
259
+ batch_len = sr # minimum length of audio
260
+ audios = []
261
+ for audio_path in audio_paths:
262
+ audio, loaded_sr = sf.read(audio_path)
263
+ if len(audio.shape) == 2:
264
+ audio = audio[:, 0]
265
+ audio = audio[: loaded_sr * 10]
266
+ audio = resampy.resample(audio, loaded_sr, sr)
267
+ audio = torch.from_numpy(audio)
268
+
269
+ if len(audio) < sr * max_length_seconds:
270
+ pad_size = sr * max_length_seconds - len(audio)
271
+ audio = torch.nn.functional.pad(audio, (0, pad_size))
272
+ audio = torch.clamp(audio, -1.0, 1.0)
273
+ if len(audio) > batch_len:
274
+ batch_len = len(audio)
275
+ audios.append(audio)
276
+ padding_mask = torch.zeros((len(audios), batch_len), dtype=torch.bool)
277
+ for i in range(len(audios)):
278
+ if len(audios[i]) < batch_len:
279
+ pad_len = batch_len - len(audios[i])
280
+ sil = torch.zeros(pad_len, dtype=torch.float32)
281
+ padding_mask[i, len(audios[i]) :] = True # mark the padded positions before extending the waveform
282
+ audios[i] = torch.cat((audios[i], sil), dim=0)
283
+ audios = torch.stack(audios, dim=0)
284
+
285
+ samples = {
286
+ "raw_wav": audios,
287
+ "padding_mask": padding_mask,
288
+ "audio_chunk_sizes": [len(audio_paths)],
289
+ }
290
+ if cuda_enabled:
291
+ samples = move_to_device(samples, "cuda")
292
+
293
+ return samples
294
+
295
+
296
+ def generate_sample_batches(
297
+ audio_path,
298
+ cuda_enabled: bool = True,
299
+ sr: int = TARGET_SAMPLE_RATE,
300
+ chunk_len: int = 10,
301
+ hop_len: int = 5,
302
+ batch_size: int = 4,
303
+ ):
304
+ audio, loaded_sr = sf.read(audio_path)
305
+ if len(audio.shape) == 2: # stereo to mono
306
+ audio = audio.mean(axis=1)
307
+ audio = torchaudio.functional.resample(torch.from_numpy(audio), loaded_sr, sr)
308
+ hop_len = hop_len * sr
309
+ chunk_len = min(len(audio), chunk_len * sr)  # chunk length in samples; use the whole clip if it is shorter
310
+ chunks = []
311
+
312
+ for i in range(0, len(audio), hop_len):
313
+ chunk = audio[i : i + chunk_len]
314
+ if len(chunk) < chunk_len:
315
+ break
316
+ chunks.append(chunk)
317
+
318
+ for i in range(0, len(chunks), batch_size):
319
+ batch = chunks[i : i + batch_size]
320
+ padding_mask = torch.zeros((len(batch), chunk_len), dtype=torch.bool)  # chunk_len is already in samples
321
+ batch = torch.stack(batch, dim=0)
322
+ samples = {
323
+ "raw_wav": batch,
324
+ "padding_mask": padding_mask,
325
+ "audio_chunk_sizes": [1 for _ in range(len(batch))],
326
+ }
327
+ if cuda_enabled:
328
+ samples = move_to_device(samples, "cuda")
329
+ yield samples
330
+
331
+
332
+ def prepare_samples_for_detection(samples, prompt, label):
333
+ prompts = [prompt for i in range(len(samples["raw_wav"]))]
334
+ labels = [label for i in range(len(samples["raw_wav"]))]
335
+ task = ["detection" for i in range(len(samples["raw_wav"]))]
336
+ samples["prompt"] = prompts
337
+ samples["text"] = labels
338
+ samples["task"] = task
339
+ return samples
340
+
341
+
342
+ def universal_torch_load(
343
+ f: str | os.PathLike,
344
+ *,
345
+ cache_mode: Literal["none", "use", "force"] = "none",
346
+ **kwargs,
347
+ ) -> Any:
348
+ """
349
+ Wrapper function for torch.load that can handle GCS paths.
350
+
351
+ This function provides a convenient way to load PyTorch objects from both local and
352
+ Google Cloud Storage (GCS) paths. For GCS paths, it can optionally caches the
353
+ downloaded files locally to avoid repeated downloads.
354
+
355
+ The cache location is determined by:
356
+ 1. The ESP_CACHE_HOME environment variable if set
357
+ 2. Otherwise defaults to ~/.cache/esp/
358
+
359
+ Args:
360
+ f: File-like object, string or PathLike object.
361
+ Can be a local path or a GCS path (starting with 'gs://').
362
+ cache_mode (str, optional): Cache mode for GCS files. Options are:
363
+ "none": No caching (use bucket directly)
364
+ "use": Use cache if available, download if not
365
+ "force": Force redownload even if cache exists
366
+ Defaults to "none".
367
+ **kwargs: Additional keyword arguments passed to torch.load().
368
+
369
+ Returns:
370
+ The object loaded from the file using torch.load.
371
+
372
+ Raises:
373
+ IsADirectoryError: If the GCS path points to a directory instead of a file.
374
+ FileNotFoundError: If the local file does not exist.
375
+ """
376
+
377
+ f = Path(f)
378
+ if not f.exists():
379
+ raise FileNotFoundError(f"File does not exist: {f}")
380
+
381
+ with open(f, "rb") as opened_file:
382
+ return torch.load(opened_file, **kwargs)
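To make the expected tensor layout concrete, a short inference-side sketch; the audio path and prompt are placeholders, and the model.generate call mirrors how Runner.valid_epoch uses it rather than a verified public API.

    # Illustrative only -- assumes a loaded NatureLM `model` and a parsed `cfg` are in scope.
    paths = ["example_recording.wav"]  # placeholder file
    samples = prepare_sample_waveforms(paths, cuda_enabled=torch.cuda.is_available())
    # samples["raw_wav"]:      (batch, n_samples) waveforms resampled to 16 kHz and zero-padded
    # samples["padding_mask"]: (batch, n_samples) bool mask, True on padded positions

    samples = prepare_samples_for_detection(
        samples,
        prompt="Which species are audible in this recording?",
        label="",  # labels are only needed when scoring
    )
    # outputs = model.generate(samples, cfg.generate, prompts=samples["prompt"])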
README.md CHANGED
@@ -1,14 +1,34 @@
1
  ---
2
- title: NatureLM Audio
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.40.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Description
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: NatureLM Audio Demo
3
+ emoji: 🎵
4
+ colorFrom: purple
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Audio analysis with NatureLM model
12
  ---
13
 
14
+ # NatureLM Audio Demo
15
+
16
+ This is a demo of the NatureLM audio analysis model. The app provides three main features:
17
+
18
+ ## Features
19
+
20
+ 1. **Chat Interface**: Upload audio files and ask questions about them
21
+ 2. **Batch Processing**: Process multiple audio files with the same task
22
+ 3. **Long Recording Analysis**: Analyze long audio recordings by chunking them
23
+
24
+ ## Usage
25
+
26
+ - **First Use**: The model will load automatically when you first use it (this may take a few minutes)
27
+ - **Subsequent Uses**: The model stays loaded for faster responses
28
+ - **Demo Mode**: If the model fails to load, the app will run in demo mode
29
+
30
+ ## Model Loading
31
+
32
+ The app uses lazy loading to start quickly. The model is only loaded when you first interact with it, not during app initialization. This prevents timeout issues on HuggingFace Spaces.
33
+
34
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Space.yaml ADDED
@@ -0,0 +1,3 @@
1
+ sdk: gradio
2
+ python_version: 3.10
3
+ hardware: cpu
configs/inference.yml ADDED
@@ -0,0 +1,61 @@
1
+ model:
2
+ llama_path: "meta-llama/Meta-Llama-3.1-8B-Instruct"
3
+
4
+ freeze_beats: True
5
+ device: "cuda"
6
+ use_audio_Qformer: True
7
+ max_pooling: False
8
+ downsample_factor: 8
9
+ freeze_audio_QFormer: False
10
+ window_level_Qformer: True
11
+ num_audio_query_token: 1
12
+ second_per_window: 0.333333
13
+ second_stride: 0.333333
14
+
15
+ audio_llama_proj_model: ""
16
+ freeze_audio_llama_proj: False
17
+
18
+ lora: True
19
+ lora_rank: 32
20
+ lora_alpha: 32
21
+ lora_dropout: 0.1
22
+
23
+ prompt_template: "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
24
+ max_txt_len: 160
25
+ end_sym: <|end_of_text|>
26
+
27
+ beats_cfg:
28
+ input_patch_size: 16
29
+ embed_dim: 512
30
+ conv_bias: False
31
+ encoder_layers: 12
32
+ encoder_embed_dim: 768
33
+ encoder_ffn_embed_dim: 3072
34
+ encoder_attention_heads: 12
35
+ activation_fn: "gelu"
36
+ layer_wise_gradient_decay_ratio: 0.6
37
+ layer_norm_first: False
38
+ deep_norm: True
39
+ dropout: 0.0
40
+ attention_dropout: 0.0
41
+ activation_dropout: 0.0
42
+ encoder_layerdrop: 0.05
43
+ dropout_input: 0.0
44
+ conv_pos: 128
45
+ conv_pos_groups: 16
46
+ relative_position_embedding: True
47
+ num_buckets: 320
48
+ max_distance: 800
49
+ gru_rel_pos: True
50
+ finetuned_model: True
51
+ predictor_dropout: 0.0
52
+ predictor_class: 527
53
+
54
+ generate:
55
+ max_new_tokens: 300
56
+ num_beams: 2
57
+ do_sample: False
58
+ min_length: 1
59
+ temperature: 0.1
60
+ repetition_penalty: 1.0
61
+ length_penalty: 1.0
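For context, a hedged sketch of reading this file directly; in the codebase these values are normally wrapped in NatureLM.config.Config (a pydantic-style model, given the model_dump calls in the Runner) before reaching model.generate.

    # Illustrative: inspect the decoding parameters defined in configs/inference.yml.
    import yaml

    with open("configs/inference.yml") as fh:
        cfg = yaml.safe_load(fh)

    print(cfg["generate"]["num_beams"])        # 2
    print(cfg["generate"]["max_new_tokens"])   # 300
    print(cfg["model"]["lora_rank"])           # 32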
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ torch>=2.2.2
2
+ torchaudio>=2.2.2
3
+ torchvision>=0.17.2
4
+ transformers[sentencepiece]>=4.44.2
5
+ datasets>=2.20.0
6
+ cloudpathlib[gs]>=0.20.0
7
+ einops>=0.8.0
8
+ gradio>=5.10.0
9
+ google-cloud-aiplatform>=1.76.0
10
+ Levenshtein>=0.25.1
11
+ librosa>=0.9.2
12
+ memoization>=0.4.0
13
+ mir-eval>=0.7
14
+ numpy>=1.26.4
15
+ pandas>=1.4.3
16
+ peft>=0.11.1
17
+ plumbum>=1.7.2
18
+ pydantic-settings>=2.7.1
19
+ pydantic>=2.7.4
20
+ pydub>=0.25.1
21
+ pyyaml>=6.0
22
+ resampy>=0.3.1
23
+ scipy>=1.14.0
24
+ soundfile>=0.12.1
25
+ tensorboard>=2.18.0
26
+ tensorboardX>=2.6.2.2
27
+ tqdm>=4.66.4
28
+ wandb>=0.17.3
29
+ click>=8.1.7
30
+ git+https://github.com/earthspecies/beans-zero.git