Spaces: Running on Zero

Cheeky Sparrow committed 426874e
Parent(s): 2c6a5a0

push

Files changed:
- .gitattributes +11 -0
- NatureLM/__init__.py +19 -0
- NatureLM/augmentations.py +349 -0
- NatureLM/checkpoint_utils.py +100 -0
- NatureLM/config.py +234 -0
- NatureLM/dataset.py +550 -0
- NatureLM/dist_utils.py +109 -0
- NatureLM/infer.py +315 -0
- NatureLM/logger.py +190 -0
- NatureLM/models/NatureLM.py +666 -0
- NatureLM/models/Qformer.py +1091 -0
- NatureLM/models/__init__.py +19 -0
- NatureLM/models/__pycache__/NatureLM.cpython-310.pyc +0 -0
- NatureLM/models/__pycache__/Qformer.cpython-310.pyc +0 -0
- NatureLM/models/__pycache__/__init__.cpython-310.pyc +0 -0
- NatureLM/models/__pycache__/utils.cpython-310.pyc +0 -0
- NatureLM/models/aves.py +59 -0
- NatureLM/models/beats/BEATs.py +181 -0
- NatureLM/models/beats/Tokenizers.py +173 -0
- NatureLM/models/beats/__init__.py +0 -0
- NatureLM/models/beats/__pycache__/BEATs.cpython-310.pyc +0 -0
- NatureLM/models/beats/__pycache__/__init__.cpython-310.pyc +0 -0
- NatureLM/models/beats/__pycache__/backbone.cpython-310.pyc +0 -0
- NatureLM/models/beats/__pycache__/modules.cpython-310.pyc +0 -0
- NatureLM/models/beats/backbone.py +741 -0
- NatureLM/models/beats/modules.py +201 -0
- NatureLM/models/beats/quantizer.py +222 -0
- NatureLM/models/utils.py +29 -0
- NatureLM/optims.py +154 -0
- NatureLM/processors.py +278 -0
- NatureLM/runner.py +515 -0
- NatureLM/storage_utils.py +26 -0
- NatureLM/task_metric_utils.py +283 -0
- NatureLM/task_metrics.py +128 -0
- NatureLM/utils.py +382 -0
- README.md +26 -6
- Space.yaml +3 -0
- configs/inference.yml +61 -0
- requirements.txt +30 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/484366__spacejoe__bird-3.wav filter=lfs diff=lfs merge=lfs -text
+assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/nri-battlesounds.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/nri-GreenTreeFrogEvergladesNP.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/nri-StreamMUWO.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/esp_favicon.png filter=lfs diff=lfs merge=lfs -text
+assets/naturelm-audio-overiew.png filter=lfs diff=lfs merge=lfs -text
+assets/nri-SensationJazz.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/yell-YELLAMRO20160506SM3.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/yell-YELLFLBCSACR20075171.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/yell-YELLWolfvCar20160111T22ms2.mp3 filter=lfs diff=lfs merge=lfs -text
NatureLM/__init__.py ADDED
@@ -0,0 +1,19 @@
# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import Config
from .models.NatureLM import NatureLM
from .utils import generate_sample_batches, prepare_sample_waveforms

__all__ = ["Config", "NatureLM", "generate_sample_batches", "prepare_sample_waveforms"]
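For orientation, a minimal usage sketch of the exports above, assuming the dependencies from requirements.txt are installed. Only names declared in `__all__` are used; `configs/inference.yml` is the config file added later in this commit, and the call signatures of `NatureLM`, `generate_sample_batches` and `prepare_sample_waveforms` are not part of this hunk, so they are imported but not called here.

# Hypothetical usage of the public API re-exported above (sketch only).
from NatureLM import Config, NatureLM, generate_sample_batches, prepare_sample_waveforms

cfg = Config.from_sources("configs/inference.yml")  # YAML shipped with this commit
cfg.pretty_print()  # dump the merged, validated config as JSON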
NatureLM/augmentations.py ADDED
@@ -0,0 +1,349 @@
import logging
import random

import numpy as np
import torch as th
from torch import nn
from torch.nn import functional as F

from NatureLM.utils import mel_frequencies

logger = logging.getLogger(__name__)


class RevEcho(nn.Module):
    """
    Hacky Reverb but runs on GPU without slowing down training. This reverb adds a
    succession of attenuated echos of the input signal to itself. Intuitively, the delay
    of the first echo will happen after roughly 2x the radius of the room and is
    controlled by `first_delay`. Then RevEcho keeps adding echos with the same delay and
    further attenuation until the amplitude ratio between the last and first echo is
    1e-3. The attenuation factor and the number of echos to adds is controlled by RT60
    (measured in seconds). RT60 is the average time to get to -60dB (n.b. volume is
    measured over the squared amplitude so this matches the 1e-3 ratio).

    At each call to RevEcho, `first_delay`, `initial` and `RT60` are sampled from their
    range. Then, to prevent this reverb from being too regular, the delay time is
    resampled uniformly within `first_delay +/- 10%`, as controlled by the `jitter`
    parameter.

    Finally, for a denser reverb, multiple trains of echos are added with different
    jitter noises.

    Args:
        - initial: amplitude of the first echo as a fraction of the input signal. For
          each sample, actually sampled from `[0, initial]`. Larger values means louder
          reverb. Physically, this would depend on the absorption of the room walls.
        - rt60: range of values to sample the RT60 in seconds, i.e. after RT60 seconds,
          the echo amplitude is 1e-3 of the first echo. The default values follow the
          recommendations of https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf,
          Section 2.4. Physically this would also be related to the absorption of the
          room walls and there is likely a relation between `RT60` and `initial`, which
          we ignore here.
        - first_delay: range of values to sample the first echo delay in seconds. The
          default values are equivalent to sampling a room of 3 to 10 meters.
        - repeat: how many train of echos with differents jitters to add. Higher values
          means a denser reverb.
        - jitter: jitter used to make each repetition of the reverb echo train slightly
          different. For instance a jitter of 0.1 means the delay between two echos will
          be in the range `first_delay +- 10%`, with the jittering noise being resampled
          after each single echo.
        - keep_clean: fraction of the reverb of the clean speech to add back to the
          ground truth. 0 = dereverberation, 1 = no dereverberation.
        - sample_rate: sample rate of the input signals.
    """

    def __init__(
        self,
        proba=0.5,
        initial=0.3,
        rt60=(0.3, 1.3),
        first_delay=(0.01, 0.03),
        repeat=3,
        jitter=0.1,
        keep_clean=0.1,
        sample_rate=16000,
        rng=None,
        seed=42,
    ):
        super().__init__()

        self.proba = proba
        self.initial = initial
        self.rt60 = rt60
        self.first_delay = first_delay
        self.repeat = repeat
        self.jitter = jitter
        self.keep_clean = keep_clean
        self.sample_rate = sample_rate
        self.seed = seed
        self.rng = rng if rng is not None else random.Random(self.seed)

    def _reverb(self, source, initial, first_delay, rt60):
        """
        Return the reverb for a single source.
        """
        length = source.shape[-1]
        reverb = th.zeros_like(source)

        for _ in range(self.repeat):
            frac = 1  # what fraction of the first echo amplitude is still here
            echo = initial * source
            while frac > 1e-3:
                # First jitter noise for the delay
                jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
                delay = min(1 + int(jitter * first_delay * self.sample_rate), length)

                # Delay the echo in time by padding with zero on the left
                echo = F.pad(echo[:, :, :-delay], (delay, 0))
                reverb += echo

                # Second jitter noise for the attenuation
                jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
                # we want, with `d` the attenuation, d**(rt60 / first_ms) = 1e-3
                # i.e. log10(d) = -3 * first_ms / rt60, so that
                attenuation = 10 ** (-3 * jitter * first_delay / rt60)
                echo *= attenuation
                frac *= attenuation

        return reverb

    def forward(self, samples):
        if self.rng.random() >= self.proba:
            return samples

        raw_wav = samples.get("raw_wav", None)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)

        # Sample characteristics for the reverb
        initial = self.rng.random() * self.initial
        first_delay = self.rng.uniform(*self.first_delay)
        rt60 = self.rng.uniform(*self.rt60)

        reverb_wav = self._reverb(raw_wav, initial, first_delay, rt60)
        raw_wav += self.keep_clean * reverb_wav

        # remove channel dimension
        if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
            raw_wav = raw_wav.squeeze(1)

        samples["raw_wav"] = raw_wav
        return samples


class BandMask(nn.Module):
    """
    Maskes bands of frequencies. Similar to Park, Daniel S., et al.
    "Specaugment: A simple data augmentation method for automatic speech recognition."
    (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
    """

    def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000, rng=None, seed=42):
        """__init__.

        :param maxwidth: the maximum width to remove
        :param bands: number of bands
        :param sample_rate: signal sample rate
        """
        super().__init__()
        self.maxwidth = maxwidth
        self.bands = bands
        self.sample_rate = sample_rate
        self.seed = seed
        self.rng = rng if rng is not None else random.Random(self.seed)

    def forward(self, samples):
        raw_wav = samples.get("raw_wav", None)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)

        bands = self.bands
        bandwidth = int(abs(self.maxwidth) * bands)
        mels = mel_frequencies(bands, 40, self.sample_rate / 2) / self.sample_rate
        low = self.rng.randrange(bands)
        high = self.rng.randrange(low, min(bands, low + bandwidth))

        filters = LowPassFilters([mels[low], mels[high]]).to(raw_wav.device)

        low, midlow = filters(raw_wav)
        # band pass filtering
        out = raw_wav - midlow + low

        # remove channel dimension
        if out.dim() == 3 and out.shape[1] == 1:
            out = out.squeeze(1)

        samples["raw_wav"] = out
        return samples


class Shift(nn.Module):
    def __init__(self, shift=8192, same=False, rngth=None):
        """
        :param shift: randomly shifts the signals up to a given factor
        :param same: shifts both clean and noisy files by the same factor
        """
        super().__init__()
        self.shift = shift
        self.same = same
        self.rngth = rngth

    def forward(self, samples):
        raw_wav = samples.get("raw_wav", None)
        batch, channels, length = raw_wav.shape
        length = length - self.shift
        if self.shift > 0:
            offsets = th.randint(
                self.shift, [1 if self.same else batch, 1, 1], device=raw_wav.device, generator=self.rngth
            )
            offsets = offsets.expand(-1, channels, -1)
            indexes = th.arange(length, device=raw_wav.device)
            import pdb

            pdb.set_trace()
            raw_wav = raw_wav.gather(2, indexes + offsets)
        samples["raw_wav"] = raw_wav
        return samples


class TimeScale(nn.Module):
    """Fast time scale."""

    def __init__(self, scale=2.0, target=1, rngnp=None, seed=42):
        """
        :param scale: randomly scales up to this maximum factor
        """
        super().__init__()
        self.scale = scale
        self.target = target
        self.seed = seed
        self.rngnp = rngnp if rngnp is not None else np.random.default_rng(seed=self.seed)

    def forward(self, samples):
        try:
            raw_wav = samples.get("raw_wav")
        except KeyError:
            logger.error("Missing required key 'raw_wav' in samples dict")
            raise

        if "padding_mask" in samples:
            masks = samples.get("padding_mask")
        else:
            masks = th.ones_like(raw_wav)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)
            masks = masks.unsqueeze(1)

        # what to augment: noise, clean, or both
        if self.target == -1:
            targets = [i for i in range(raw_wav.shape[0])]
        else:
            targets = [self.target]

        for t in targets:
            signal = raw_wav[t]
            scaling = np.power(self.scale, self.rngnp.uniform(-1, 1))
            output_size = int(signal.shape[-1] * scaling)
            ref = th.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)

            ref1 = ref.clone().type(th.int64)
            ref2 = th.min(ref1 + 1, th.full_like(ref1, signal.shape[-1] - 1, dtype=th.int64))
            r = ref - ref1.type(ref.type())
            scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
            scaled_masks = masks[t][..., ref1] * (1 - r) + masks[t][..., ref2] * r

            # trim or zero pad to the original size
            if scaled_signal.shape[-1] > signal.shape[-1]:
                nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
                scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
                scaled_masks = scaled_masks[..., nframes_offset : nframes_offset + signal.shape[-1]]
            else:
                nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
                pad_left = int(np.random.uniform() * nframes_diff)
                pad_right = nframes_diff - pad_left
                scaled_signal = F.pad(
                    input=scaled_signal, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
                )
                scaled_masks = F.pad(
                    input=scaled_masks, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
                )
            raw_wav[t] = scaled_signal
            masks[t] = scaled_masks

        # remove channel dimension
        if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
            raw_wav = raw_wav.squeeze(1)
            masks = masks.squeeze(1)

        samples["raw_wav"] = raw_wav
        samples["padding_mask"] = masks

        return samples


class Flip(nn.Module):
    def __init__(self, p=0.0, rngth=None):
        super(Flip, self).__init__()

        self.p = p
        self.rngth = rngth

    def forward(self, samples):
        raw_wav = samples["raw_wav"]
        if raw_wav.dim() > 2:
            flip_mask = th.rand(raw_wav.shape[0], device=raw_wav.device, generator=self.rngth) <= self.p
            raw_wav[flip_mask] = raw_wav[flip_mask].flip(-1)
        else:
            if th.rand(1, generator=self.rngth) <= self.p:
                raw_wav = raw_wav.flip(0)
        samples["raw_wav"] = raw_wav
        return samples


class LowPassFilters(th.nn.Module):
    """
    Bank of low pass filters.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
            f_s is the samplerate.
        width (int | None): width of the filters (i.e. kernel_size=2 * width + 1).
            Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
            but more side effects.
    Shape:
        - Input: `(*, T)`
        - Output: `(F, *, T` with `F` the len of `cutoffs`.
    """

    def __init__(self, cutoffs: list, width: int | None = None):
        super().__init__()

        self.cutoffs = cutoffs

        if not width:
            width = int(2 / min(cutoffs))
        self.width = width

        window = th.hamming_window(2 * width + 1, periodic=False)
        t = np.arange(-width, width + 1, dtype=np.float32)
        filters = []
        for cutoff in cutoffs:
            sinc = th.from_numpy(np.sinc(2 * cutoff * t))
            filters.append(2 * cutoff * sinc * window)
        self.register_buffer("filters", th.stack(filters).unsqueeze(1))

    def forward(self, input):
        *others, t = input.shape
        input = input.view(-1, 1, t)
        out = F.conv1d(input, self.filters, padding=self.width)
        return out.permute(1, 0, 2).reshape(-1, *others, t)

    def __repr__(self):
        return "LossPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
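The modules above all consume a `samples` dict holding a `raw_wav` tensor and hand the dict back, so they can be chained directly. A minimal sketch on random audio, assuming `NatureLM.utils.mel_frequencies` (imported above but defined elsewhere in this commit) behaves like the usual mel-frequency helper; `proba=1.0` forces `RevEcho` to fire on every call.

# Sketch: chain two of the waveform augmentations on a toy batch of 4 clips.
import torch as th
from NatureLM.augmentations import BandMask, RevEcho

batch = {"raw_wav": th.randn(4, 16000 * 3)}  # (batch, time), 3 s of 16 kHz audio

reverb = RevEcho(proba=1.0, sample_rate=16000, seed=0)
bandmask = BandMask(maxwidth=0.2, bands=120, sample_rate=16000, seed=0)

batch = reverb(batch)    # adds decaying echo trains to each waveform
batch = bandmask(batch)  # masks a random mel frequency band
print(batch["raw_wav"].shape)  # torch.Size([4, 48000])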
NatureLM/checkpoint_utils.py ADDED
@@ -0,0 +1,100 @@
"""Module for training utilities.

This module contains utility functions for training models. For example, saving model checkpoints.
"""

import logging
import os
import tempfile
from typing import Any, Union

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module:
    return model.module if use_distributed else model


def get_state_dict(model, drop_untrained_params: bool = True) -> dict[str, Any]:
    """Get model state dict. Optionally drop untrained parameters to keep only those that require gradient.

    Args:
        model: Model to get state dict from
        drop_untrained_params: Whether to drop untrained parameters

    Returns:
        dict: Model state dict
    """
    if not drop_untrained_params:
        return model.state_dict()

    param_grad_dict = {k: v.requires_grad for (k, v) in model.named_parameters()}
    state_dict = model.state_dict()

    for k in list(state_dict.keys()):
        if k in param_grad_dict.keys() and not param_grad_dict[k]:
            # delete parameters that do not require gradient
            del state_dict[k]

    return state_dict


def torch_save_to_bucket(save_obj: Any, save_path: Union[str, os.PathLike], compress: bool = True) -> None:
    """Save an object directly to GCS bucket without intermediate disk storage.

    Args:
        save_obj: Object to save (usually model state dict or checkpoint)
        save_path: Path to save in GCS bucket (must be gs:// path)
        compress: Whether to use compression. Default: True
    """
    if not is_gcs_path(save_path):
        raise ValueError("save_path must be a GCS path")

    # save to a temporary local file and then upload to GCS
    with tempfile.NamedTemporaryFile() as tmp:
        torch.save(save_obj, tmp.name, _use_new_zipfile_serialization=compress)
        try:
            save_path.upload_from(tmp.name)
        except Exception as e:
            logger.error(f"Error saving to GCP bucket: {e}")
            raise e


def save_model_checkpoint(
    model: nn.Module,
    save_path: Union[str, os.PathLike],
    use_distributed: bool = False,
    drop_untrained_params: bool = False,
    **objects_to_save,
) -> None:
    """Save model checkpoint.

    Args:
        model (nn.Module): Model to save
        output_dir (str): Output directory to save checkpoint
        use_distributed (bool): Whether the model is distributed, if so, unwrap it. Default: False.
        is_best (bool): Whether the model is the best in the training run. Default: False.
        drop_untrained_params (bool): Whether to drop untrained parameters to save. Default: True.
        prefix (str): Prefix to add to the checkpoint file name. Default: "".
        extention (str): Extension to use for the checkpoint file. Default: "pth".
        **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
    """
    if not is_gcs_path(save_path) and not os.path.exists(os.path.dirname(save_path)):
        raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")

    model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
    state_dict = get_state_dict(model_no_ddp, drop_untrained_params)
    save_obj = {
        "model": state_dict,
        **objects_to_save,
    }

    logger.info("Saving checkpoint to {}.".format(save_path))

    if is_gcs_path(save_path):
        torch_save_to_bucket(save_obj, save_path)
    else:
        torch.save(save_obj, save_path)
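A minimal sketch of the `drop_untrained_params` behaviour shown above, using a toy model. The GCS branch of `save_model_checkpoint` goes through `is_gcs_path` and `upload_from`, which are referenced but not imported in this hunk (presumably they come from NatureLM/storage_utils.py), so the sketch sticks to a plain local `torch.save`.

# Sketch: keep only parameters that still require gradients, then save them.
import torch
import torch.nn as nn
from NatureLM.checkpoint_utils import get_state_dict

model = nn.Linear(8, 2)
model.weight.requires_grad_(False)  # pretend the weight matrix is frozen

state = get_state_dict(model, drop_untrained_params=True)
print(sorted(state.keys()))  # ['bias'] -- the frozen 'weight' entry was dropped
torch.save({"model": state, "epoch": 0}, "/tmp/naturelm_demo.pth")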
NatureLM/config.py ADDED
@@ -0,0 +1,234 @@
# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Any, Literal

import yaml
from pydantic import BaseModel, field_validator
from pydantic.v1.utils import deep_update
from pydantic_settings import BaseSettings, CliSettingsSource, YamlConfigSettingsSource


class OptimizerConfig(BaseModel, extra="forbid", validate_assignment=True):
    max_epoch: int
    warmup_steps: int
    warmup_start_lr: float = -1
    init_lr: float
    min_lr: float
    weight_decay: float
    beta2: float = 0.999
    max_grad_norm: float | None = None
    max_grad_value: float | None = None
    device: str = "cuda"


class AugmentationsConfig(BaseModel, extra="forbid", validate_assignment=True):
    use_augmentation: bool = False

    noise_prob: float = 0
    noise_dirs: list[Path] | None = None
    low_snr: float = -5
    high_snr: float = 20
    time_scale_prob: float = 0
    time_scale: float = 1.2
    mixup_prob: float = 0
    mixup_count: int = 3
    mask_audio_prob: float = 0


class RunConfig(BaseModel, extra="forbid", validate_assignment=True):
    wandb_enabled: bool = True
    amp: bool = False
    seed: int
    output_dir: Path
    evaluate: bool
    log_freq: int
    epoch_based: bool
    iters_per_epoch: int
    accum_grad_iters: int
    batch_size_train: int
    batch_size_eval: int
    num_workers: int
    custom_metrics: bool
    decode_ratio: float

    device: Literal["cuda", "cpu"] = "cuda"
    use_distributed: bool = False

    world_size: int = 1
    rank: int = 0
    gpu: int | None = None
    dist_backend: Literal["nccl"] = "nccl"
    dist_url: str = "env://"

    optims: OptimizerConfig
    augmentations: AugmentationsConfig


class DatasetsConfig(BaseModel, extra="forbid", validate_assignment=True):
    train_ann_path: Path
    valid_ann_path: Path
    test_ann_path: Path
    audio_max_length_seconds: int

    @field_validator("train_ann_path", "valid_ann_path", "test_ann_path", mode="after")
    @classmethod
    def check_files(cls, path: Path) -> Path:
        if not path.exists():
            raise ValueError(f"File {path} does not exist")
        if path.suffix.lower() != ".jsonl":
            raise ValueError(f"File {path} must be a JSONL file")
        return path


class BeatsConfig(BaseModel, extra="forbid", validate_assignment=True):
    input_patch_size: int = -1
    embed_dim: int = 512
    conv_bias: bool = False

    encoder_layers: int = 12
    encoder_embed_dim: int = 768
    encoder_ffn_embed_dim: int = 3072
    encoder_attention_heads: int = 12
    activation_fn: str = "gelu"

    layer_wise_gradient_decay_ratio: float = 0.6
    layer_norm_first: bool = False
    deep_norm: bool = True

    dropout: float = 0.0
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    encoder_layerdrop: float = 0.05
    dropout_input: float = 0.0

    conv_pos: int = 128
    conv_pos_groups: int = 16

    relative_position_embedding: bool = True
    num_buckets: int = 320
    max_distance: int = 800
    gru_rel_pos: bool = True

    finetuned_model: bool = True
    predictor_dropout: float = 0.0
    predictor_class: int = 527


class GenerateConfig(BaseModel, extra="forbid", validate_assignment=True):
    max_new_tokens: int
    num_beams: int
    do_sample: bool
    min_length: int
    temperature: float
    repetition_penalty: float
    length_penalty: float


class ModelConfig(BaseModel, extra="forbid", validate_assignment=True):
    llama_path: Path
    beats_path: Path | None = None
    beats_cfg: BeatsConfig
    ckpt: Path | None = None
    freeze_beats: bool = True
    use_audio_Qformer: bool = True
    max_pooling: bool = False
    downsample_factor: int = 4
    freeze_audio_QFormer: bool = False
    window_level_Qformer: bool = True
    num_audio_query_token: int = 1
    second_per_window: float = 0.333333
    second_stride: float = 0.333333
    audio_llama_proj_model: Path | None = None
    freeze_audio_llama_proj: bool = False
    device: str = "cuda"
    lora: bool = True
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    flash_attn: Literal["eager", "flash_attention_2"] = "eager"
    prompt_template: str = ""
    max_txt_len: int = 128
    end_sym: str = "</s>"

    @field_validator("beats_path", "audio_llama_proj_model", "ckpt", mode="before")
    @classmethod
    def detect_gcs_path(cls, value: Any) -> Any:
        """Pydantic's automatic type conversion won't be able to deal with gs:// paths
        so we need to manually detect and convert them to GSPath objects _before_
        validation"""
        return value

    @field_validator("ckpt", "audio_llama_proj_model", mode="before")
    @classmethod
    def legacy_empty_str(cls, value: Any) -> Any:
        """In some of our config files we use "" to indicate that we don't have
        a checkpoint. We've now switched to using None for this in the Config model but
        let's keep this validator for backwards compatibility so people don't have to
        change their configs"""
        if isinstance(value, str) and value == "":
            return None
        else:
            return value

    @classmethod
    def from_yaml(cls, yaml_file: str | os.PathLike) -> "ModelConfig":
        yaml_values = YamlConfigSettingsSource(cls, yaml_file=str(yaml_file))
        return cls.model_validate(yaml_values())


class Config(BaseSettings, extra="forbid", validate_assignment=True):
    model: ModelConfig
    run: RunConfig | None = None
    datasets: DatasetsConfig | None = None
    generate: GenerateConfig | None = None

    def pretty_print(self):
        print(self.model_dump_json(indent=4))

    @classmethod
    def from_sources(cls, yaml_file: str | Path, cli_args: list[str] = []) -> "Config":
        """Create a Config object from a YAML file and CLI arguments. If there are
        any conflicts, the CLI arguments will take precedence over the YAML file."""

        yaml_file = Path(yaml_file)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Config file {yaml_file} does not exist")

        yaml_values = YamlConfigSettingsSource(cls, yaml_file=yaml_file)
        cli_values = CliSettingsSource(cls, cli_parse_args=["--" + opt for opt in cli_args])
        final_values = deep_update(yaml_values(), cli_values())
        return cls.model_validate(final_values)

    def to_yaml(self, path: str | os.PathLike) -> None:
        save_config_as_yaml(self, path)


def save_config_as_yaml(data: BaseModel, filepath: str | os.PathLike) -> None:
    """
    Pydantic supports serializing/exporting models to various formats (dict, json, etc)
    but not to yaml. This function is a workaround for that limitation.
    """

    filepath = Path(filepath)

    if filepath.exists():
        raise FileExistsError(f"File {filepath} already exists")

    # The mode="json" is required because otherwise yaml.same_dump() can't deal with
    # Path|GSPath objects
    with filepath.open("w") as f:
        yaml.safe_dump(data.model_dump(mode="json"), f, sort_keys=False)
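A minimal sketch of `Config.from_sources` as defined above, assuming `configs/inference.yml` (added in this commit) fills in the required `model` section. Each `cli_args` entry is prefixed with `--` before being handed to `CliSettingsSource`, and CLI values win over the YAML on conflicts.

# Sketch: load the YAML config and override one nested field from the CLI.
from NatureLM.config import Config

cfg = Config.from_sources(
    "configs/inference.yml",
    cli_args=["model.lora_rank=16"],  # parsed as --model.lora_rank=16
)
cfg.pretty_print()
cfg.to_yaml("/tmp/merged_config.yml")  # raises FileExistsError if the file is already there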
NatureLM/dataset.py ADDED
@@ -0,0 +1,550 @@
# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Mixing examples.
Can mix:
- base: options-detection add: open-ended:
    Take all open-ended labels. Add them to the options. Add them to the labels.
- base: open-ended, add: open-ended
    Concatenate labels
"""

import glob
import json
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import Literal

import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from NatureLM.utils import snr_scale, time_scale


def write_example_to_file(base_filename, audio, sr=16000, suffix="_output", save_dir="debug_outputs"):
    """
    Writes the audio tensor to a file for debugging or inspection purposes.

    Args:
        base_filename (str): The base name of the original file.
        audio (torch.Tensor or numpy.ndarray): The audio waveform to save.
        sr (int): Sampling rate of the audio (default: 16000 Hz).
        suffix (str): Optional suffix to append to the filename.
        save_dir (str): Directory where the files will be saved.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.numpy()  # Convert to numpy if necessary

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Create the output file path
    filename = f"{os.path.splitext(base_filename)[0]}{suffix}.wav"
    output_path = os.path.join(save_dir, filename)

    try:
        # Write the audio to the file
        sf.write(output_path, audio, sr)
        print(f"Saved audio to {output_path}")
    except Exception as e:
        print(f"Failed to write audio to file: {e}")


# Example usage in your code
# write_example_to_file(os.path.basename(ann["path"]), audio, suffix="_ts")


def collater(samples):
    """Collate samples into a batch.

    Samples is a list of dictionaries, each containing the following keys:
    - raw_wav: a list of tensors containing the raw audio waveform
    - text: a list of strings containing the text
    - task: a list of strings containing the task
    - id: a list of strings containing the id
    - prompt: a list of strings containing the prompt
    - index: a list of integers containing the index

    The indiviudal audio waveforms will be stacked along the batch dimension for easier
    processing in the audio model. To keep which audio belongs to which sample, we add
    the audio_chunk_sizes key to the batch dictionary.
    """
    flat_raw_wav = []
    audio_chunk_sizes = []

    for s in samples:
        chunk_size = len(s["raw_wav"])
        audio_chunk_sizes.append(chunk_size)
        flat_raw_wav.extend(s["raw_wav"])
    # raw_wav = [torch.from_numpy(a) for a in flat_raw_wav]
    raw_wav = flat_raw_wav
    raw_wav_length = torch.tensor([len(a) for a in raw_wav])
    raw_wav = pad_sequence(raw_wav, batch_first=True, padding_value=0)
    paddding_mask = torch.arange(raw_wav.size(1)).unsqueeze(0) >= raw_wav_length.unsqueeze(1)

    text = [s["text"] for s in samples]
    prompt = [s["prompt"] for s in samples]
    task = [s["task"] for s in samples]
    id = [s["id"] for s in samples]
    index = [s["index"] for s in samples]

    return {
        "raw_wav": raw_wav,
        "padding_mask": paddding_mask,
        "text": text,
        "task": task,
        "id": id,
        "prompt": prompt,
        "index": index,
        "audio_chunk_sizes": audio_chunk_sizes,
    }


class NatureLMDataset(Dataset):
    def __init__(
        self,
        ann_path: str | Path,
        *,
        max_length_seconds: int = 10,
        cropping: Literal["random", "start"] | None = "random",
        noise_prob: float = 0.0,
        noise_dirs: list[str] | list[Path] | None = None,
        low_snr: float = -5,
        high_snr: float = 20,
        time_scale_prob: float = 0.0,
        time_scale: float = 1.2,
        seed: int = 0,
        mixup_prob: float = 0.0,
        mixup_count: int = 3,
        use_augmentation: bool = False,
        mask_audio_prob: float = 0.0,
    ):
        super().__init__()

        ann_path = Path(ann_path)

        if not ann_path.exists():
            raise FileNotFoundError(f"Dataset file {ann_path} not found")

        try:
            with open(ann_path, "r") as f:
                data = json.load(f)
                self.annotation = data["annotation"]
        except (json.JSONDecodeError, KeyError):
            with open(ann_path, "r") as f:
                self.annotation = [json.loads(line) for line in f]

        #### mixup related variables
        ### hash table for tasks to sample the tasks faster
        self.tasks = defaultdict(list)
        for i, ann in enumerate(self.annotation):
            if "task" in ann and "text" in ann and ann["text"] != "None" and "path" in ann:
                self.tasks[ann["task"]].append(i)

        self.mixup_tasks = {
            task: []
            for task in self.tasks.keys()
            if task.endswith("simple-detection")
            or task.endswith("multiple-detection")  # Add more tasks after validating prompt mixing.
            or task.endswith("sci-detection-random")
            or task.endswith("common-detection-random")
        }
        for k in self.mixup_tasks.keys():
            # whichever the base, only mix open-ended tasks.
            if "sci-" in k:
                self.mixup_tasks[k] = [
                    task
                    for task in self.mixup_tasks.keys()
                    if task.endswith("sci-simple-detection") or task.endswith("sci-multiple-detection")
                ]
            elif "common-" in k:
                self.mixup_tasks[k] = [
                    task
                    for task in self.mixup_tasks.keys()
                    if task.endswith("common-simple-detection") or task.endswith("common-multiple-detection")
                ]
            else:
                self.mixup_tasks[k] = [task for task in self.mixup_tasks.keys() if "common-" in task]

        # print("num annotations", len(self.annotation))
        # print("annotation 0", self.annotation[0])
        # self.annotation = [a for a in self.annotation if "task" in a and "detection" not in a["task"]] # no detection... :(
        self.max_length_seconds = max_length_seconds
        self.cropping = cropping
        self.use_augmentation = use_augmentation

        ### noise augmentation
        self.rng = random.Random(seed)
        self.rngnp = np.random.default_rng(seed=seed)
        self.noise_dirs = noise_dirs
        self.noise_prob = noise_prob
        self.noise_files = []
        self.low_snr = low_snr
        self.high_snr = high_snr
        self.mask_audio_prob = mask_audio_prob
        if noise_dirs is not None and len(self.noise_dirs) > 0 and self.use_augmentation:
            for noise_dir in noise_dirs:
                noise_from_dir = glob.glob(os.path.join(noise_dir, "*.wav"))
                if len(noise_from_dir) < 3000:
                    noise_from_dir = noise_from_dir * 3
                print("noise files from dir", noise_dir, len(noise_from_dir))
                self.noise_files.extend(noise_from_dir)

        ### mixup augmentation
        self.mixup_prob = mixup_prob
        self.mixup_count = mixup_count
        # ### time scale augmentation
        self.time_scale = time_scale
        self.time_scale_prob = time_scale_prob
        # tasks = set([annotation["task"] if "task" in annotation else "empty" for annotation in self.annotation])
        print(":::all tasks:::", self.tasks.keys())
        print("num examples", len(self.annotation))

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return collater(samples)

    def load_audio(self, audio_path, shift_allowed: bool, noise_allowed: bool):
        audio, sr = sf.read(audio_path)
        # assert sr == 16000
        if sr != 16000:
            print("other sr!", sr, audio_path)
        if len(audio.shape) == 2:  # stereo to mono
            audio = audio.mean(axis=1)

        ### time scale augmentation
        if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
            # print(f"{index} scaling audio")
            # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] )
            audio = time_scale(torch.tensor(audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
            # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] , suffix='_ts')

        # Randomly crop a max_length_seconds window if audio is longer than 10 seconds
        if len(audio) > sr * self.max_length_seconds and self.cropping == "random":
            max_start = len(audio) - sr * self.max_length_seconds
            start = random.randint(0, max_start)
            audio = audio[start : start + sr * self.max_length_seconds]
        else:  # no random cropping
            audio = audio[: sr * self.max_length_seconds]  # Truncate audio to at most max_length_seconds

        ### noise augmentation
        audio = torch.tensor(audio)
        ### noise augmentation
        if (
            self.use_augmentation
            and self.rng.random() < self.noise_prob
            and len(self.noise_files) > 0
            and noise_allowed
        ):
            # write_example_to_file(os.path.basename(ann["path"]), audio)
            # print(f"{index} adding noise")
            noise_file = self.rng.choice(self.noise_files)
            if not os.path.exists(noise_file):
                print(f"Warning: noise file {noise_file} does not exist")
            else:
                noise_audio, noise_sr = sf.read(noise_file)
                assert noise_sr == 16000
                if len(noise_audio.shape) == 2:
                    noise_audio = noise_audio.mean(axis=1)

                noise_audio = torch.tensor(noise_audio)

                ### repeat or trim to the audio size
                if len(audio) > len(noise_audio):
                    if len(noise_audio) == 0:
                        print(
                            "----- Warning: Noise audio length is zero. ---------- ",
                            noise_file,
                        )
                        # Option 1: Skip noise augmentation by setting noise_audio to zero
                        noise_audio = torch.zeros_like(audio)
                    else:
                        nrepeats = int(np.maximum(2, np.ceil(len(audio) / len(noise_audio))))
                        noise_audio = noise_audio.repeat(nrepeats)
                ### Randomly crop the noise file if it is too long
                if len(noise_audio) > len(audio):
                    max_start = len(noise_audio) - len(audio)
                    start = random.randint(0, max_start)
                    noise_audio = noise_audio[start : start + len(audio)]

                ### remix with specified snr
                snr = self.rngnp.uniform(self.low_snr, self.high_snr)
                snr = torch.tensor([snr])
                noise_audio = snr_scale(audio, noise_audio, snr)
                audio = audio + noise_audio

                # write_example_to_file(os.path.basename(audio_path), audio, suffix='_noise')
                if len(audio) > self.max_length_seconds * sr:
                    print("long audio", len(audio), len(noise_audio))
                    audio = audio[: self.max_length_seconds * sr]

        # pad all audios to max_len_seconds in _getitem_ to ensure no padding inconsistencies.
        if len(audio) < sr * self.max_length_seconds:
            pad_size = sr * self.max_length_seconds - len(audio)
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        audio = torch.clamp(audio, -1.0, 1.0)

        return audio

    def _mix_labels(self, text, text_to_mix):
        """
        Given two comma-separated label strings (e.g., "gorilla, zebra"),
        combine them without introducing duplicates. If either is "None",
        return the other as-is (unless both are "None").
        """
        # If `text_to_mix` is explicitly "None", just return `text`.
        if text_to_mix == "None":
            return text

        # If `text` is explicitly "None", just return `text_to_mix`.
        if text == "None":
            return text_to_mix

        # Split both strings by comma, stripping whitespace
        text_list = [item.strip() for item in text.split(",") if item.strip()]
        text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]

        # Deduplicate: add only new items from text_to_mix_list
        combined_set = set(text_list)
        for item in text_to_mix_list:
            if item not in combined_set:
                text_list.append(item)
                combined_set.add(item)

        # If there's nothing left after deduplication, return "None".
        if not text_list:
            return "None"

        # Rejoin them into a comma-separated string
        return ", ".join(text_list)

    def _mix_prompts(self, text, text_to_mix, prompt):
        """
        If the prompt is in the form:
            "Which of these, if any, are present in the audio recording? option1, option2, ..."

        1. Parse out the question (before '?') and the list of prompt choices (after '?').
        2. Convert both `text` and `text_to_mix` into lists, checking for items not in the prompt.
        3. Append any missing answers to the prompt choices.
        4. Shuffle the choices.
        5. Reassemble and return the new prompt.

        If the prompt does not follow the expected structure, it is returned unmodified.
        """
        # Split into two parts: question + choices
        splitted = prompt.split("?")
        if len(splitted) != 2:
            # If we don't have exactly one question mark segment, just return the original prompt
            return prompt

        question = splitted[0].strip()
        potential_choices_str = splitted[1].strip()

        # Split the prompt choices
        if not potential_choices_str:
            prompt_choices = []
        else:
            prompt_choices = [c.strip() for c in potential_choices_str.split(",") if c.strip()]

        # Parse `text`
        text_list = [item.strip() for item in text.split(",") if item.strip()]

        # Parse `text_to_mix`
        text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]

        # Add any new items from text_list to the prompt
        for item in text_list:
            if item not in prompt_choices:
                prompt_choices.append(item)

        # Add any new items from text_to_mix_list to the prompt
        for item in text_to_mix_list:
            if item not in prompt_choices:
                prompt_choices.append(item)

        # Shuffle consistently with self.rng
        self.rng.shuffle(prompt_choices)

        # Reassemble
        new_prompt = question + "? " + ", ".join(prompt_choices)
        return new_prompt

    def _apply_mixup(self, prompt, audio, text, task, filename=None):
        # mixup_applied = False
        if (
            self.use_augmentation and self.rng.random() < self.mixup_prob and task in self.mixup_tasks
            # and text != "None" # Allow complex 'None' examples.
        ):
            # write_example_to_file(os.path.basename(ann["path"]), audio)
            # print(f"{index} mixing up")
            mixup_indices = []
            for pair_task in self.mixup_tasks[task]:
                mixup_indices.extend(self.tasks[pair_task])
            # mixup_indices = mixup_indices.remove(index)

            if len(mixup_indices) == 0:
                print("No mixup partner found")
            else:
                ### choose n_mixup random partners
                n_mixup = self.rng.randint(1, self.mixup_count)
                mixup_indices = self.rng.sample(mixup_indices, n_mixup)
                # print(f"Mixing up with indices {mixup_indices}")
                for mixup_index in mixup_indices:
                    mixup_ann = self.annotation[mixup_index]
                    mixup_audio, _ = sf.read(mixup_ann["path"])
                    if len(mixup_audio.shape) == 2:
                        mixup_audio = mixup_audio.mean(axis=1)
                    mixup_audio = mixup_audio[: len(audio)]
                    if len(mixup_audio) < len(audio):
                        pad_size = len(audio) - len(mixup_audio)
                        mixup_audio = np.pad(mixup_audio, (0, pad_size), mode="constant")
                    mixup_audio = torch.from_numpy(mixup_audio).float()
                    lam = np.clip(self.rngnp.beta(1.0, 1.0), 0.1, 0.8)

                    # Mix the raw_wav
                    audio = lam * audio + (1 - lam) * mixup_audio

                    ### Mix the prompts if the labels are given in prompts
                    if text in prompt:
                        prompt = self._mix_prompts(text, mixup_ann["text"], prompt)

                    ### Mix the labels
                    text = self._mix_labels(text, mixup_ann["text"])

                    # mixup_applied = True

        # DEBUG: If mixup was actually applied, save the final audio
        # if mixup_applied and filename is not None:
        #     # Just add a suffix to the original filename to indicate mixup
        #     base_filename = os.path.basename(filename)
        #     write_example_to_file(
        #         base_filename=base_filename,
        #         audio=audio,
        #         sr=16000,
        #         suffix="_mixup",
        #         save_dir="mixup_outputs"
        #     )
        #     print(f"mixup for {filename}::: prompt {prompt} label {text}")

        return prompt, audio, text

    def _load_noise(self, shift_allowed: bool):
        noise_file = self.rng.choice(self.noise_files)
        noise_audio, noise_sr = sf.read(noise_file)
        assert noise_sr == 16000, f"Expected noise sample rate 16000, got {noise_sr}"
        if len(noise_audio.shape) == 2:
            noise_audio = noise_audio.mean(axis=1)

        # Time scale augmentation if applicable
        if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
|
| 461 |
+
noise_audio = time_scale(torch.tensor(noise_audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
|
| 462 |
+
|
| 463 |
+
# Randomly crop or pad to match max_length_seconds
|
| 464 |
+
if len(noise_audio) > self.max_length_seconds * 16000 and self.cropping == "random":
|
| 465 |
+
max_start = len(noise_audio) - self.max_length_seconds * 16000
|
| 466 |
+
start = random.randint(0, max_start)
|
| 467 |
+
noise_audio = noise_audio[start : start + self.max_length_seconds * 16000]
|
| 468 |
+
else:
|
| 469 |
+
noise_audio = noise_audio[: self.max_length_seconds * 16000]
|
| 470 |
+
|
| 471 |
+
# Pad if needed
|
| 472 |
+
if len(noise_audio) < self.max_length_seconds * 16000:
|
| 473 |
+
pad_size = self.max_length_seconds * 16000 - len(noise_audio)
|
| 474 |
+
noise_audio = np.pad(noise_audio, (0, pad_size), mode="constant")
|
| 475 |
+
|
| 476 |
+
noise_audio = torch.tensor(noise_audio).float()
|
| 477 |
+
noise_audio = torch.clamp(noise_audio, -1.0, 1.0)
|
| 478 |
+
return noise_audio
|
| 479 |
+
|
| 480 |
+
def __getitem__(self, index):
|
| 481 |
+
ann = self.annotation[index]
|
| 482 |
+
# print("loading audio::", ann)
|
| 483 |
+
shift_allowed = "pitch" not in ann.get("task", "")
|
| 484 |
+
noise_allowed = (
|
| 485 |
+
"/A/" not in ann.get("path", "")
|
| 486 |
+
and "-qa" not in ann.get("task", "")
|
| 487 |
+
and "icl" not in ann.get("task", "")
|
| 488 |
+
and "caption" not in ann.get("task", "")
|
| 489 |
+
and "animal-instructions" not in ann.get("task", "")
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
task = ann.get("task", "asr")
|
| 493 |
+
text = ann["text"]
|
| 494 |
+
prompt = ann["prompt"]
|
| 495 |
+
|
| 496 |
+
replace_with_noise = (
|
| 497 |
+
self.use_augmentation
|
| 498 |
+
and task.endswith("detection")
|
| 499 |
+
and self.rng.random() < self.mask_audio_prob
|
| 500 |
+
and len(self.noise_files) > 0
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
if replace_with_noise:
|
| 504 |
+
# Replace audio with noise
|
| 505 |
+
audio = self._load_noise(shift_allowed)
|
| 506 |
+
audios = [audio]
|
| 507 |
+
text = "None"
|
| 508 |
+
|
| 509 |
+
else:
|
| 510 |
+
if "path" in ann and ann["path"] is not None:
|
| 511 |
+
audio = self.load_audio(ann["path"], shift_allowed, noise_allowed)
|
| 512 |
+
audios = [audio]
|
| 513 |
+
else:
|
| 514 |
+
audios = [self.load_audio(p, shift_allowed, noise_allowed) for p in ann["files"]]
|
| 515 |
+
|
| 516 |
+
if len(audios) == 1:
|
| 517 |
+
prompt, mixed_audio, text = self._apply_mixup(prompt, audio, text, task, filename=ann["path"])
|
| 518 |
+
audios = [mixed_audio]
|
| 519 |
+
|
| 520 |
+
return {
|
| 521 |
+
"raw_wav": audios,
|
| 522 |
+
"text": text,
|
| 523 |
+
"task": task,
|
| 524 |
+
"id": ann.get("path") or ";".join(ann["files"]),
|
| 525 |
+
"prompt": prompt,
|
| 526 |
+
"index": index, # track which element for eval output
|
| 527 |
+
"ann": ann, # Include annotation for mixup
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if __name__ == "__main__":
|
| 532 |
+
dataset = NatureLMDataset(
|
| 533 |
+
ann_path="/home/ubuntu/foundation-model-storage/foundation-model-data/data/compiled-datasets/v1/s2_eval_valid.jsonl",
|
| 534 |
+
noise_dirs=["resource/audio_demo"],
|
| 535 |
+
max_length_seconds=10,
|
| 536 |
+
use_augmentation=True,
|
| 537 |
+
mixup_prob=1.0, # For demonstration, force mixup if possible
|
| 538 |
+
mixup_count=2, # Up to 2 mixup partners
|
| 539 |
+
mask_audio_prob=0.2,
|
| 540 |
+
seed=42,
|
| 541 |
+
noise_prob=0.5,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
# Process just a few to see the saved mixups
|
| 545 |
+
for i in range(300):
|
| 546 |
+
sample = dataset[i]
|
| 547 |
+
# print("Final text:", sample["text"])
|
| 548 |
+
# print("Final prompt:", sample["prompt"])
|
| 549 |
+
# print("-" * 40)
|
| 550 |
+
print("Done! Look in 'debug_outputs' folder for saved mixup files.")
|
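For reference, the deduplication rule implemented by `_mix_labels` can be exercised in isolation. The snippet below is a minimal standalone sketch that mirrors the method as a free function (it does not construct `NatureLMDataset`, which requires annotation files); the function name `mix_labels` is illustrative only, and the expected outputs are shown in the comments.

# Standalone sketch of the label-mixing rule used by _mix_labels above.
def mix_labels(text: str, text_to_mix: str) -> str:
    if text_to_mix == "None":
        return text
    if text == "None":
        return text_to_mix
    items = [s.strip() for s in text.split(",") if s.strip()]
    seen = set(items)
    for s in (x.strip() for x in text_to_mix.split(",") if x.strip()):
        if s not in seen:
            items.append(s)
            seen.add(s)
    return ", ".join(items) if items else "None"

print(mix_labels("gorilla, zebra", "zebra, lion"))  # -> "gorilla, zebra, lion"
print(mix_labels("None", "lion"))                   # -> "lion"
print(mix_labels("gorilla", "None"))                # -> "gorilla"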
NatureLM/dist_utils.py
ADDED
@@ -0,0 +1,109 @@
"""
Adapted from salesforce@LAVIS. Below is the original copyright:
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.use_distributed = False
        return

    args.use_distributed = True

    torch.cuda.set_device(args.gpu)
    print(
        "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(days=365),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
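A small usage sketch for the helpers in dist_utils.py: the `main_process` decorator is meant to guard side effects so they run only on rank 0, and the rank helpers degrade gracefully when torch.distributed is not initialised. The function name `save_report` below is illustrative, not part of the repository; in a single-process run `get_rank()` returns 0, so the guarded call executes normally.

from NatureLM.dist_utils import get_rank, get_world_size, main_process

@main_process
def save_report(path: str, text: str) -> None:
    # Executes only on rank 0; on other ranks the wrapper simply returns None.
    with open(path, "w") as f:
        f.write(text)

if __name__ == "__main__":
    print(f"rank {get_rank()} of {get_world_size()} process(es)")
    save_report("report.txt", "training finished")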
NatureLM/infer.py
ADDED
@@ -0,0 +1,315 @@
"""Run NatureLM-audio over a set of audio file paths or a directory with audio files."""

import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf
import torch

from NatureLM.config import Config
from NatureLM.models import NatureLM
from NatureLM.processors import NatureLMAudioProcessor
from NatureLM.utils import move_to_device

_MAX_LENGTH_SECONDS = 10
_MIN_CHUNK_LENGTH_SECONDS = 0.5
_SAMPLE_RATE = 16000  # Assuming the model uses a sample rate of 16kHz
_AUDIO_FILE_EXTENSIONS = [".wav", ".mp3", ".flac", ".ogg"]  # Add other audio file formats as needed
_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
__this_dir = Path(__file__).parent.parent
_DEFAULT_CONFIG_PATH = __this_dir / "configs" / "inference.yml"


def load_model_and_config(
    cfg_path: str | Path = _DEFAULT_CONFIG_PATH, device: str = _DEVICE
) -> tuple[NatureLM, Config]:
    """Load the NatureLM model and configuration.

    Returns:
        tuple: The loaded model and configuration.
    """
    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
    model = model.to(device).eval()
    model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
    model.llama_model.generation_config.pad_token_id = model.llama_tokenizer.pad_token_id

    cfg = Config.from_sources(cfg_path)
    return model, cfg


def output_template(model_output: str, start_time: float, end_time: float) -> str:
    """Format the output of the model."""
    return f"#{start_time:.2f}s - {end_time:.2f}s#: {model_output}\n"


def sliding_window_inference(
    audio: str | Path | np.ndarray,
    query: str,
    processor: NatureLMAudioProcessor,
    model: NatureLM,
    cfg: Config,
    window_length_seconds: float = 10.0,
    hop_length_seconds: float = 10.0,
    input_sr: int = _SAMPLE_RATE,
    device: str = _DEVICE,
) -> str:
    """Run inference on a long audio file using a sliding-window approach.

    Args:
        audio (str | Path | np.ndarray): Path to the audio file, or a loaded audio array.
        query (str): Query for the model.
        processor (NatureLMAudioProcessor): Audio processor.
        model (NatureLM): NatureLM model.
        cfg (Config): Model configuration.
        window_length_seconds (float): Length of the sliding window in seconds.
        hop_length_seconds (float): Hop length for the sliding window in seconds.
        input_sr (int): Sample rate of the audio file.

    Returns:
        str: The output of the model.

    Raises:
        ValueError: If the audio file is too short or if the audio file path is invalid.
    """
    if isinstance(audio, str) or isinstance(audio, Path):
        audio_array, input_sr = sf.read(str(audio))
    elif isinstance(audio, np.ndarray):
        audio_array = audio
        print(f"Using provided sample rate: {input_sr}")

    audio_array = audio_array.squeeze()
    if audio_array.ndim > 1:
        axis_to_average = int(np.argmin(audio_array.shape))
        audio_array = audio_array.mean(axis=axis_to_average)
        audio_array = audio_array.squeeze()

    # Do initial check that the audio is long enough
    if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
        raise ValueError(f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds.")

    start = 0
    stride = int(hop_length_seconds * input_sr)
    window_length = int(window_length_seconds * input_sr)

    output = ""
    while True:
        chunk = audio_array[start : start + window_length]
        if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
            break

        # Resamples, pads, truncates and creates torch Tensor
        audio_tensor, prompt_list = processor([chunk], [query], [input_sr])

        input_to_model = {
            "raw_wav": audio_tensor,
            "prompt": prompt_list[0],
            "audio_chunk_sizes": 1,
            "padding_mask": torch.zeros_like(audio_tensor).to(torch.bool),
        }
        input_to_model = move_to_device(input_to_model, device)

        # generate
        prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]

        # Post-process the prediction
        prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
        output += prediction

        # Move the window
        start += stride

        if start + window_length > audio_array.shape[-1]:
            break

    return output


class Pipeline:
    """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""

    def __init__(self, model: NatureLM = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH):
        self.cfg_path = cfg_path

        # Load model and config
        if model is not None:
            self.cfg = Config.from_sources(cfg_path)
            self.model = model
        else:
            # Download model from hub
            self.model, self.cfg = load_model_and_config(cfg_path)

        self.processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)

    def __call__(
        self,
        audios: list[str | Path | np.ndarray],
        queries: str | list[str],
        window_length_seconds: float = 10.0,
        hop_length_seconds: float = 10.0,
        input_sample_rate: int = _SAMPLE_RATE,
        verbose: bool = False,
    ) -> list[str]:
        """Run inference on a list of audio file paths or a single audio file with a
        single query or a list of queries. If multiple queries are provided,
        we assume that they are in the same order as the audio files. If a single query
        is provided, it will be used for all audio files.

        Args:
            audios (list[str | Path | np.ndarray]): List of audio file paths or a single audio file path or audio array(s)
            queries (str | list[str]): Queries for the model.
            window_length_seconds (float): Length of the sliding window in seconds. Defaults to 10.0.
            hop_length_seconds (float): Hop length for the sliding window in seconds. Defaults to 10.0.
            input_sample_rate (int): Sample rate of the audio. Defaults to 16000, which is the model's sample rate.
            verbose (bool): If True, print the output of the model for each audio file.
                Defaults to False.

        Returns:
            list[str]: The output of the model.

        Raises:
            ValueError: If the number of audio files and queries do not match.

        Example:
            >>> pipeline = Pipeline()
            >>> audios = ["assets/nri-GreenTreeFrogEvergladesNP.mp3"]
            >>> queries = ["Which species is this? Provide the common name."]
            >>> results = pipeline(audios, queries)
            >>> print(results)
            ['#0.00s - 10.00s#: Green Treefrog\n']
        """
        if isinstance(audios, str) or isinstance(audios, Path):
            audios = [audios]

        if isinstance(queries, str):
            queries = [queries] * len(audios)

        if len(audios) != len(queries):
            raise ValueError("Number of audio files and queries must match.")

        # Run inference
        results = []
        for audio, query in zip(audios, queries):
            output = sliding_window_inference(
                audio,
                query,
                self.processor,
                self.model,
                self.cfg,
                window_length_seconds,
                hop_length_seconds,
                input_sr=input_sample_rate,
            )
            results.append(output)
            if verbose:
                print(f"Processed {audio}, model output:\n=======\n{output}\n=======")
        return results


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("Run NatureLM-audio inference")
    parser.add_argument(
        "-a", "--audio", type=str, required=True, help="Path to an audio file or a directory containing audio files"
    )
    parser.add_argument("-q", "--query", type=str, required=True, help="Query for the model")
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="configs/inference.yml",
        help="Path to the configuration file for the model",
    )
    parser.add_argument("--output_path", type=str, default="inference_output.jsonl", help="Output path for the results")
    parser.add_argument(
        "--window_length_seconds", type=float, default=10.0, help="Length of the sliding window in seconds"
    )
    parser.add_argument(
        "--hop_length_seconds", type=float, default=10.0, help="Hop length for the sliding window in seconds"
    )
    args = parser.parse_args()

    return args


def main(
    cfg_path: str | Path,
    audio_path: str | Path,
    query: str,
    output_path: str,
    window_length_seconds: float,
    hop_length_seconds: float,
) -> None:
    """Main function to run the NatureLM-audio inference script.
    It takes command line arguments for audio file path, query, output path,
    window length, and hop length. It processes the audio files and saves the
    results to a JSONL file.

    Args:
        cfg_path (str | Path): Path to the configuration file.
        audio_path (str | Path): Path to the audio file or directory.
        query (str): Query for the model.
        output_path (str): Path to save the output results.
        window_length_seconds (float): Length of the sliding window in seconds.
        hop_length_seconds (float): Hop length for the sliding window in seconds.

    Raises:
        ValueError: If the audio file path is invalid or if the query is empty.
        ValueError: If no audio files are found.
        ValueError: If the audio file extension is not supported.
    """

    # Prepare sample
    audio_path = Path(audio_path)
    if audio_path.is_dir():
        audio_paths = []
        print(f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}")
        for ext in _AUDIO_FILE_EXTENSIONS:
            audio_paths.extend(list(audio_path.rglob(f"*{ext}")))

        print(f"Found {len(audio_paths)} audio files in {str(audio_path)}")
    else:
        # check that the extension is valid
        if not any(audio_path.suffix == ext for ext in _AUDIO_FILE_EXTENSIONS):
            raise ValueError(
                f"Invalid audio file extension. Supported extensions are: {', '.join(_AUDIO_FILE_EXTENSIONS)}"
            )
        audio_paths = [audio_path]

    # check that query is not empty
    if not query:
        raise ValueError("Query cannot be empty")
    if not audio_paths:
        raise ValueError("No audio files found. Please check the path or file extensions.")

    # Load model and config
    model, cfg = load_model_and_config(cfg_path)

    # Load audio processor
    processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)

    # Run inference
    results = {"audio_path": [], "output": []}
    for path in audio_paths:
        output = sliding_window_inference(path, query, processor, model, cfg, window_length_seconds, hop_length_seconds)
        results["audio_path"].append(str(path))
        results["output"].append(output)
        print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")

    # Save results as JSON lines
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(results)
    df.to_json(output_path, orient="records", lines=True)
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    args = parse_args()
    main(
        cfg_path=args.cfg_path,
        audio_path=args.audio,
        query=args.query,
        output_path=args.output_path,
        window_length_seconds=args.window_length_seconds,
        hop_length_seconds=args.hop_length_seconds,
    )
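The `Pipeline` class above is the intended entry point for programmatic use. The sketch below follows the docstring example; the asset path and query are illustrative, and the first `Pipeline()` call downloads EarthSpeciesProject/NatureLM-audio from the Hub. The same flow is exposed on the command line through the `-a/--audio` and `-q/--query` flags defined in `parse_args`, by running infer.py as a script.

from NatureLM.infer import Pipeline

pipeline = Pipeline()  # loads model and the default configs/inference.yml

results = pipeline(
    audios=["assets/nri-GreenTreeFrogEvergladesNP.mp3"],
    queries="Which species is this? Provide the common name.",
    window_length_seconds=10.0,
    hop_length_seconds=10.0,
    verbose=True,
)
print(results[0])  # e.g. "#0.00s - 10.00s#: Green Treefrog"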
NatureLM/logger.py
ADDED
@@ -0,0 +1,190 @@
import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist
import wandb

from NatureLM.dist_utils import is_dist_avail_and_initialized, is_main_process


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None, start_step=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                if is_main_process():
                    if logger is not None:
                        assert start_step is not None, "start_step is needed to compute global_step!"
                        for name, meter in self.meters.items():
                            logger.add_scalar("{}".format(name), float(str(meter)), global_step=start_step + i)
                        # Log to wandb
                        wandb.log({name: float(str(meter)) for name, meter in self.meters.items()}, step=start_step + i)
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
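A minimal sketch of how `SmoothedValue` and `MetricLogger` are typically driven from a training loop. The loop below is synthetic (no model and no TensorBoard writer), so `log_every` is called with `logger=None` and only prints to stdout; `wandb` is still imported by the module but not used on this path.

import time
from NatureLM.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in metric_logger.log_every(range(100), print_freq=20, header="Train:"):
    time.sleep(0.01)  # stand-in for a forward/backward pass
    metric_logger.update(loss=1.0 / (step + 1), lr=1e-4)

print("averaged stats:", metric_logger.global_avg())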
NatureLM/models/NatureLM.py
ADDED
@@ -0,0 +1,666 @@
| 1 |
+
# Copyright (2024) Earth Species Project
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Literal, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 24 |
+
from peft import LoraConfig, TaskType, get_peft_model
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
|
| 28 |
+
|
| 29 |
+
from NatureLM.checkpoint_utils import save_model_checkpoint
|
| 30 |
+
from NatureLM.config import BeatsConfig, ModelConfig, save_config_as_yaml
|
| 31 |
+
from NatureLM.utils import universal_torch_load
|
| 32 |
+
|
| 33 |
+
from .beats.BEATs import BEATs, BEATsConfig
|
| 34 |
+
from .Qformer import BertConfig, BertLMHeadModel
|
| 35 |
+
from .utils import StoppingCriteriaSub
|
| 36 |
+
|
| 37 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 38 |
+
|
| 39 |
+
auth_token = os.getenv('llama')
|
| 40 |
+
|
| 41 |
+
class NatureLM(nn.Module, PyTorchModelHubMixin):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
*,
|
| 45 |
+
llama_path: Path,
|
| 46 |
+
beats_path: Path | os.PathLike | None = None,
|
| 47 |
+
beats_cfg: BeatsConfig,
|
| 48 |
+
freeze_beats: bool = True,
|
| 49 |
+
use_audio_Qformer: bool = True,
|
| 50 |
+
max_pooling: bool = False,
|
| 51 |
+
num_audio_query_token: int = 1,
|
| 52 |
+
freeze_audio_QFormer: bool = False,
|
| 53 |
+
window_level_Qformer: bool = True,
|
| 54 |
+
second_per_window: float = 0.333333,
|
| 55 |
+
second_stride: float = 0.333333,
|
| 56 |
+
downsample_factor: int = 4,
|
| 57 |
+
audio_llama_proj_model: Path | os.PathLike | None = None,
|
| 58 |
+
freeze_audio_llama_proj: bool = False,
|
| 59 |
+
lora: bool = True,
|
| 60 |
+
lora_rank: int = 8,
|
| 61 |
+
lora_alpha: int = 32,
|
| 62 |
+
lora_dropout: float = 0.1,
|
| 63 |
+
flash_attn: Literal["eager", "flash_attention_2"] = "eager",
|
| 64 |
+
prompt_template: str = "",
|
| 65 |
+
max_txt_len: int = 128,
|
| 66 |
+
end_sym: str = "</s>",
|
| 67 |
+
device: str = "cuda",
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.beats_path = beats_path
|
| 72 |
+
self.beats_cfg = beats_cfg
|
| 73 |
+
self.use_audio_Qformer = use_audio_Qformer
|
| 74 |
+
self.max_pooling = max_pooling
|
| 75 |
+
self.window_level_Qformer = window_level_Qformer
|
| 76 |
+
self.second_per_window = second_per_window
|
| 77 |
+
self.second_stride = second_stride
|
| 78 |
+
self.downsample_factor = downsample_factor
|
| 79 |
+
self.lora = lora
|
| 80 |
+
self.max_txt_len = max_txt_len
|
| 81 |
+
self.end_sym = end_sym
|
| 82 |
+
self.prompt_template = prompt_template
|
| 83 |
+
self.flash_attn = flash_attn
|
| 84 |
+
|
| 85 |
+
logging.info(f"Llama path: {llama_path}")
|
| 86 |
+
logging.info("Loading Llama Tokenizer")
|
| 87 |
+
self.llama_tokenizer = AutoTokenizer.from_pretrained(llama_path, use_fast=False, use_auth_token=auth_token)
|
| 88 |
+
self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
| 89 |
+
self.llama_tokenizer.padding_side = "right"
|
| 90 |
+
|
| 91 |
+
logging.info("Loading Llama Model")
|
| 92 |
+
if device == "cpu":
|
| 93 |
+
self.llama_model = AutoModelForCausalLM.from_pretrained(
|
| 94 |
+
llama_path,
|
| 95 |
+
torch_dtype=torch.float32,
|
| 96 |
+
attn_implementation="eager",
|
| 97 |
+
device_map="cpu",
|
| 98 |
+
use_auth_token=auth_token
|
| 99 |
+
)
|
| 100 |
+
# An issue with tiny-llama is that pad_token_id was set to -1, but
|
| 101 |
+
# model.save_pretrained checks generation configs and does not allow -1 as
|
| 102 |
+
# pad_token_id
|
| 103 |
+
self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id
|
| 104 |
+
else:
|
| 105 |
+
self.llama_model = AutoModelForCausalLM.from_pretrained(
|
| 106 |
+
llama_path,
|
| 107 |
+
torch_dtype=torch.bfloat16,
|
| 108 |
+
attn_implementation=flash_attn,
|
| 109 |
+
use_auth_token=auth_token
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
|
| 113 |
+
if self.lora:
|
| 114 |
+
for param in self.llama_model.parameters():
|
| 115 |
+
param.requires_grad = False
|
| 116 |
+
logging.info("Loading LLaMA Done")
|
| 117 |
+
self.llama_embed_tokens = self.llama_model.model.embed_tokens
|
| 118 |
+
|
| 119 |
+
if self.lora:
|
| 120 |
+
logging.info("Setting up LoRA for llama model")
|
| 121 |
+
self.peft_config = LoraConfig(
|
| 122 |
+
task_type=TaskType.CAUSAL_LM,
|
| 123 |
+
inference_mode=False,
|
| 124 |
+
r=lora_rank,
|
| 125 |
+
lora_alpha=lora_alpha,
|
| 126 |
+
lora_dropout=lora_dropout,
|
| 127 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 128 |
+
)
|
| 129 |
+
self.llama_model = get_peft_model(self.llama_model, self.peft_config)
|
| 130 |
+
self.llama_embed_tokens = self.llama_model.model.model.embed_tokens
|
| 131 |
+
self.llama_model.print_trainable_parameters()
|
| 132 |
+
logging.info("LoRA Training")
|
| 133 |
+
|
| 134 |
+
logging.info("Loading BEATs Model")
|
| 135 |
+
self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))
|
| 136 |
+
|
| 137 |
+
if self.beats_path:
|
| 138 |
+
beats_ckpt = universal_torch_load(self.beats_path, cache_mode="none", map_location="cpu")
|
| 139 |
+
self.beats.load_state_dict(beats_ckpt["model"])
|
| 140 |
+
|
| 141 |
+
self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
|
| 142 |
+
if freeze_beats:
|
| 143 |
+
for param in self.beats.parameters():
|
| 144 |
+
param.requires_grad = False
|
| 145 |
+
self.beats.eval()
|
| 146 |
+
logging.info("freeze BEATs")
|
| 147 |
+
|
| 148 |
+
if self.use_audio_Qformer:
|
| 149 |
+
self.audio_Qformer, self.audio_query_tokens = self.init_audio_Qformer(
|
| 150 |
+
num_query_token=num_audio_query_token,
|
| 151 |
+
audio_width=self.beats.cfg.encoder_embed_dim,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
self.audio_Qformer.bert.embeddings.word_embeddings = None
|
| 155 |
+
self.audio_Qformer.bert.embeddings.position_embeddings = None
|
| 156 |
+
for layer in self.audio_Qformer.bert.encoder.layer:
|
| 157 |
+
layer.output = None
|
| 158 |
+
layer.intermediate = None
|
| 159 |
+
self.audio_Qformer.cls = None
|
| 160 |
+
if freeze_audio_QFormer:
|
| 161 |
+
for param in self.audio_Qformer.parameters():
|
| 162 |
+
param.requires_grad = False
|
| 163 |
+
self.audio_Qformer.eval()
|
| 164 |
+
self.audio_query_tokens.requires_grad = False
|
| 165 |
+
logging.info("freeze audio QFormer")
|
| 166 |
+
|
| 167 |
+
logging.info("Loading audio LLAMA proj")
|
| 168 |
+
self.audio_llama_proj = nn.Linear(
|
| 169 |
+
self.audio_Qformer.config.hidden_size,
|
| 170 |
+
self.llama_model.config.hidden_size,
|
| 171 |
+
)
|
| 172 |
+
if audio_llama_proj_model:
|
| 173 |
+
logging.info(f"Loading audio LLAMA proj from {audio_llama_proj_model}")
|
| 174 |
+
# audio_llama_proj_weight = torch.load(audio_llama_proj_model, map_location="cpu")
|
| 175 |
+
audio_llama_proj_weight = universal_torch_load(
|
| 176 |
+
audio_llama_proj_model, cache_mode="use", map_location="cpu"
|
| 177 |
+
)
|
| 178 |
+
self.load_state_dict(audio_llama_proj_weight["model"], strict=False)
|
| 179 |
+
|
| 180 |
+
if freeze_audio_llama_proj:
|
| 181 |
+
for param in self.audio_llama_proj.parameters():
|
| 182 |
+
param.requires_grad = False
|
| 183 |
+
self.audio_llama_proj.eval()
|
| 184 |
+
logging.info("freeze audio LLAMA proj")
|
| 185 |
+
|
| 186 |
+
elif self.max_pooling:
|
| 187 |
+
hidden_size = (
|
| 188 |
+
768
|
| 189 |
+
if self.aves
|
| 190 |
+
else 768
|
| 191 |
+
if self.htsat
|
| 192 |
+
else 1024
|
| 193 |
+
if self.aves_large
|
| 194 |
+
else self.beats.cfg.encoder_embed_dim
|
| 195 |
+
)
|
| 196 |
+
self.audio_llama_proj = nn.Linear(
|
| 197 |
+
hidden_size, self.llama_model.config.hidden_size
|
| 198 |
+
) # Single embedding, just project to LLM.
|
| 199 |
+
|
| 200 |
+
elif self.htsat:
|
| 201 |
+
self.audio_llama_proj = nn.Linear(
|
| 202 |
+
512, self.llama_model.config.hidden_size
|
| 203 |
+
) # Single embedding, just project to LLM.
|
| 204 |
+
|
| 205 |
+
else:
|
| 206 |
+
# feel free to add other aligners here
|
| 207 |
+
raise NotImplementedError("Have to use audio qformer")
|
| 208 |
+
|
| 209 |
+
self.config: ModelConfig = None # set this in from_config
|
| 210 |
+
|
| 211 |
+
@classmethod
|
| 212 |
+
def from_config(cls, config: ModelConfig):
|
| 213 |
+
model = cls(
|
| 214 |
+
llama_path=config.llama_path,
|
| 215 |
+
beats_path=config.beats_path,
|
| 216 |
+
freeze_beats=config.freeze_beats,
|
| 217 |
+
use_audio_Qformer=config.use_audio_Qformer,
|
| 218 |
+
max_pooling=config.max_pooling,
|
| 219 |
+
num_audio_query_token=config.num_audio_query_token,
|
| 220 |
+
freeze_audio_QFormer=config.freeze_audio_QFormer,
|
| 221 |
+
window_level_Qformer=config.window_level_Qformer,
|
| 222 |
+
second_per_window=config.second_per_window,
|
| 223 |
+
second_stride=config.second_stride,
|
| 224 |
+
downsample_factor=config.downsample_factor,
|
| 225 |
+
audio_llama_proj_model=config.audio_llama_proj_model,
|
| 226 |
+
freeze_audio_llama_proj=config.freeze_audio_llama_proj,
|
| 227 |
+
lora=config.lora,
|
| 228 |
+
lora_rank=config.lora_rank,
|
| 229 |
+
lora_alpha=config.lora_alpha,
|
| 230 |
+
lora_dropout=config.lora_dropout,
|
| 231 |
+
prompt_template=config.prompt_template,
|
| 232 |
+
max_txt_len=config.max_txt_len,
|
| 233 |
+
end_sym=config.end_sym,
|
| 234 |
+
flash_attn=config.flash_attn,
|
| 235 |
+
device=config.device,
|
| 236 |
+
)
|
| 237 |
+
model.config = config
|
| 238 |
+
ckpt_path = config.ckpt
|
| 239 |
+
if ckpt_path:
|
| 240 |
+
logging.info(f"⏳ Load NatureLM ckpt from: {ckpt_path}")
|
| 241 |
+
ckpt = universal_torch_load(ckpt_path, cache_mode="use", map_location="cpu")
|
| 242 |
+
model.load_state_dict(ckpt["model"], strict=False)
|
| 243 |
+
logging.info("✅ Finished loading from ckpt")
|
| 244 |
+
|
| 245 |
+
return model
|
| 246 |
+
|
| 247 |
+
def _save_to_local(
|
| 248 |
+
self,
|
| 249 |
+
output_dir: Union[str, os.PathLike],
|
| 250 |
+
use_distributed: bool = False,
|
| 251 |
+
drop_untrained_params: bool = False,
|
| 252 |
+
) -> None:
|
| 253 |
+
output_dir = Path(output_dir)
|
| 254 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 255 |
+
|
| 256 |
+
# Save the config
|
| 257 |
+
config_path = output_dir / "model_config.yaml"
|
| 258 |
+
save_config_as_yaml(self.config, config_path)
|
| 259 |
+
|
| 260 |
+
# Save the model
|
| 261 |
+
model_path = output_dir / "model.pt"
|
| 262 |
+
save_model_checkpoint(
|
| 263 |
+
self,
|
| 264 |
+
model_path,
|
| 265 |
+
drop_untrained_params=drop_untrained_params,
|
| 266 |
+
use_distributed=use_distributed,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Save the tokenizer and llama model
|
| 270 |
+
tokenizer_path = output_dir / "llama"
|
| 271 |
+
self.llama_tokenizer.save_pretrained(tokenizer_path)
|
| 272 |
+
self.llama_model.save_pretrained(tokenizer_path)
|
| 273 |
+
|
| 274 |
+
# Save the audio model
|
| 275 |
+
if self.beats_path:
|
| 276 |
+
beats_path = output_dir / "beats.pt"
|
| 277 |
+
save_model_checkpoint(
|
| 278 |
+
self.beats,
|
| 279 |
+
beats_path,
|
| 280 |
+
drop_untrained_params=drop_untrained_params,
|
| 281 |
+
cfg=self.beats_cfg,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Save the audio projection
|
| 285 |
+
audio_llama_proj_path = output_dir / "audio_llama_proj.pt"
|
| 286 |
+
save_model_checkpoint(
|
| 287 |
+
self.audio_llama_proj,
|
| 288 |
+
audio_llama_proj_path,
|
| 289 |
+
drop_untrained_params=drop_untrained_params,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def init_audio_Qformer(num_query_token, audio_width, num_hidden_layers=2):
|
| 294 |
+
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
|
| 295 |
+
encoder_config.num_hidden_layers = num_hidden_layers
|
| 296 |
+
encoder_config.encoder_width = audio_width
|
| 297 |
+
# insert cross-attention layer every other block
|
| 298 |
+
encoder_config.add_cross_attention = True
|
| 299 |
+
encoder_config.cross_attention_freq = 1
|
| 300 |
+
encoder_config.query_length = num_query_token
|
| 301 |
+
Qformer = BertLMHeadModel(config=encoder_config)
|
| 302 |
+
query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
|
| 303 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
| 304 |
+
return Qformer, query_tokens
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def device(self):
|
| 308 |
+
return list(self.parameters())[0].device
|
| 309 |
+
|
| 310 |
+
def _encode_auditory_feature(self, audio_embeds, audio_pad_mask):
|
| 311 |
+
if self.max_pooling:
|
| 312 |
+
# Max Pooling logic to reduce sequence length
|
| 313 |
+
|
| 314 |
+
# Apply 1D Max Pooling along the time dimension
|
| 315 |
+
audio_embeds = F.max_pool1d(
|
| 316 |
+
audio_embeds.transpose(1, 2),
|
| 317 |
+
kernel_size=self.downsample_factor,
|
| 318 |
+
stride=self.downsample_factor,
|
| 319 |
+
).transpose(1, 2)
|
| 320 |
+
audio_embeds = self.audio_llama_proj(audio_embeds)
|
| 321 |
+
|
| 322 |
+
# print("audio pad mask is", audio_pad_mask)
|
| 323 |
+
audio_atts = ~audio_pad_mask
|
| 324 |
+
# Adjust the padding mask using max pooling
|
| 325 |
+
audio_atts = F.max_pool1d(
|
| 326 |
+
audio_atts.unsqueeze(1).float(),
|
| 327 |
+
kernel_size=self.downsample_factor,
|
| 328 |
+
stride=self.downsample_factor,
|
| 329 |
+
).squeeze(1)
|
| 330 |
+
audio_atts = audio_atts > 0
|
| 331 |
+
# print(f"audio pad mask shape after pooling: {audio_atts.shape}")
|
| 332 |
+
# print("audio pad mask post", audio_atts)
|
| 333 |
+
|
| 334 |
+
elif self.use_audio_Qformer:
|
| 335 |
+
# Q-Former logic
|
| 336 |
+
audio_embeds = self.ln_audio(audio_embeds)
|
| 337 |
+
|
| 338 |
+
# Generate attention mask
|
| 339 |
+
audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
|
| 340 |
+
|
| 341 |
+
if self.window_level_Qformer:
|
| 342 |
+
B, T, C = audio_embeds.shape # batch, T, Channels
|
| 343 |
+
kernel = round(1500 * self.second_per_window / 30.0) # 160 ms patches; calculate kernel size
|
| 344 |
+
stride = round(1500 * self.second_stride / 30.0) # Calculate stride size
|
| 345 |
+
kernel = (1, kernel)
|
| 346 |
+
stride = (1, stride)
|
| 347 |
+
|
| 348 |
+
# Transpose and unfold audio embeddings to create overlapping windows
|
| 349 |
+
audio_embeds_tr = audio_embeds.transpose(1, 2).unsqueeze(2)
|
| 350 |
+
audio_embeds_overlap = F.unfold(
|
| 351 |
+
audio_embeds_tr,
|
| 352 |
+
kernel_size=kernel,
|
| 353 |
+
dilation=1,
|
| 354 |
+
padding=0,
|
| 355 |
+
stride=stride,
|
| 356 |
+
)
|
| 357 |
+
_, _, L = audio_embeds_overlap.shape
|
| 358 |
+
audio_embeds_overlap = audio_embeds_overlap.view(B, -1, kernel[1], L)
|
| 359 |
+
audio_embeds_overlap = torch.permute(
|
| 360 |
+
audio_embeds_overlap, [0, 3, 2, 1]
|
| 361 |
+
) # (B, num_windows, kernel_size, C)
|
| 362 |
+
audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
|
| 363 |
+
audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
|
| 364 |
+
|
| 365 |
+
# Q-Former mechanism
|
| 366 |
+
query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
|
| 367 |
+
query_output = self.audio_Qformer.bert(
|
| 368 |
+
query_embeds=query_tokens,
|
| 369 |
+
encoder_hidden_states=audio_embeds,
|
| 370 |
+
encoder_attention_mask=audio_atts,
|
| 371 |
+
return_dict=True,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
audio_embeds = self.audio_llama_proj(query_output.last_hidden_state)
|
| 375 |
+
|
| 376 |
+
if self.window_level_Qformer:
|
| 377 |
+
audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()
|
| 378 |
+
|
| 379 |
+
audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
|
| 380 |
+
|
| 381 |
+
elif self.htsat:
|
| 382 |
+
# HTSAT processing
|
| 383 |
+
audio_embeds = self.ln_audio(audio_embeds)
|
| 384 |
+
audio_embeds = self.audio_llama_proj(audio_embeds).reshape(-1, 30, self.llama_model.config.hidden_size)
|
| 385 |
+
audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
|
| 386 |
+
|
| 387 |
+
else:
|
| 388 |
+
raise NotImplementedError("no audio qformer or max pooling")
|
| 389 |
+
|
| 390 |
+
return audio_embeds, audio_atts
|
| 391 |
+
|
| 392 |
+
def encode_audio(self, raw_wav, audio_padding_mask=None):
|
| 393 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
| 394 |
+
audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
|
| 395 |
+
return self._encode_auditory_feature(audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask)
|
| 396 |
+
|
| 397 |
+
def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
|
| 398 |
+
"""Merge audio embeddings with embeddings of the tokens in the prompt.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
audio_embeds (list): List of tensors of audio embeddings.
|
| 402 |
+
audio_atts (list): List of tensors of audio padding masks.
|
| 403 |
+
prompt (list): List of strings with the prompt for each sample. Each prompt
|
| 404 |
+
should contain the placeholder(s) "<AudioHere>" to indicate where the
|
| 405 |
+
audio embeddings should be inserted.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
tuple: A tuple containing the wrapped audio embeddings and padding masks.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
def interleave_lists(longer: list, shorter: list) -> list:
|
| 412 |
+
"""Interleave two lists where the first list is one element longer.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
longer (list): The first list with length n.
|
| 416 |
+
shorter (list): The second list with length n-1.
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
list: A new list with elements interleaved from longer and shorter.
|
| 420 |
+
|
| 421 |
+
Example:
|
| 422 |
+
>>> interleave_lists(['a1', 'a2', 'a3'], ['b1', 'b2'])
|
| 423 |
+
['a1', 'b1', 'a2', 'b2', 'a3']
|
| 424 |
+
"""
|
| 425 |
+
interleaved_list = []
|
| 426 |
+
for i in range(len(shorter)):
|
| 427 |
+
interleaved_list.append(longer[i])
|
| 428 |
+
interleaved_list.append(shorter[i])
|
| 429 |
+
interleaved_list.append(longer[-1]) # last element is from longer
|
| 430 |
+
return interleaved_list
|
| 431 |
+
|
| 432 |
+
device = audio_embeds[0].device
|
| 433 |
+
|
| 434 |
+
wrapped_embeds_list = []
|
| 435 |
+
wrapped_atts_list = []
|
| 436 |
+
batch_size = len(prompt)
|
| 437 |
+
for i in range(batch_size):
|
| 438 |
+
prompt_parts = prompt[i].split("<AudioHere>")
|
| 439 |
+
wrapped_embeds = []
|
| 440 |
+
wrapped_atts = []
|
| 441 |
+
|
| 442 |
+
for part in prompt_parts:
|
| 443 |
+
tokens = self.llama_tokenizer(part, return_tensors="pt", add_special_tokens=False).to(device)
|
| 444 |
+
part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
|
| 445 |
+
part_atts = tokens.attention_mask.squeeze(0)
|
| 446 |
+
wrapped_embeds.append(part_embeds)
|
| 447 |
+
wrapped_atts.append(part_atts)
|
| 448 |
+
|
| 449 |
+
# Process each element in the batch to remove padding
|
| 450 |
+
if self.max_pooling:
|
| 451 |
+
audio_embeds[i] = list(audio_embeds[i].unbind(0))
|
| 452 |
+
audio_atts[i] = list(audio_atts[i].unbind(0))
|
| 453 |
+
for j in range(len(audio_embeds[i])):
|
| 454 |
+
audio_embeds[i][j] = audio_embeds[i][j][audio_atts[i][j]]
|
| 455 |
+
audio_atts[i][j] = audio_atts[i][j][audio_atts[i][j]]
|
| 456 |
+
|
| 457 |
+
# Interleave wrapped_embeds and audio_embeds using interleave_lists
|
| 458 |
+
wrapped_embeds = interleave_lists(wrapped_embeds, audio_embeds[i])
|
| 459 |
+
wrapped_atts = interleave_lists(wrapped_atts, audio_atts[i])
|
| 460 |
+
|
| 461 |
+
wrapped_embeds = torch.cat(wrapped_embeds, dim=0)
|
| 462 |
+
wrapped_atts = torch.cat(wrapped_atts, dim=0)
|
| 463 |
+
wrapped_embeds_list.append(wrapped_embeds)
|
| 464 |
+
wrapped_atts_list.append(wrapped_atts)
|
| 465 |
+
|
| 466 |
+
wrapped_embeds = pad_sequence(wrapped_embeds_list, batch_first=True)
|
| 467 |
+
wrapped_atts = pad_sequence(wrapped_atts_list, batch_first=True)
|
| 468 |
+
return wrapped_embeds, wrapped_atts
|
| 469 |
+
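# ---------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the committed file):
# the interleaving pattern prompt_wrap relies on. A prompt with N
# "<AudioHere>" placeholders splits into N+1 text pieces, and the N audio
# chunks slot between them. Strings stand in for embedding tensors here.
# ---------------------------------------------------------------------
prompt = "Listen to <AudioHere> and then <AudioHere>. What species is calling?"
text_pieces = prompt.split("<AudioHere>")      # N+1 text pieces
audio_chunks = ["<emb_0>", "<emb_1>"]          # N audio embedding chunks

def _interleave(longer, shorter):
    out = []
    for i in range(len(shorter)):
        out.extend([longer[i], shorter[i]])
    out.append(longer[-1])
    return out

print(_interleave(text_pieces, audio_chunks))
# ['Listen to ', '<emb_0>', ' and then ', '<emb_1>', '. What species is calling?']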
|
| 470 |
+
def forward(self, samples, verbose=True):
|
| 471 |
+
# Prepare prompts
|
| 472 |
+
prompt = samples["prompt"]
|
| 473 |
+
prompt = [self.prompt_template.format(p) for p in prompt]
|
| 474 |
+
|
| 475 |
+
# Use the audio encoder to encode the raw audio
|
| 476 |
+
raw_wav = samples.get("raw_wav", None)
|
| 477 |
+
audio_padding_mask = samples.get("padding_mask", None)
|
| 478 |
+
|
| 479 |
+
audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask)
|
| 480 |
+
audio_chunk_sizes = samples["audio_chunk_sizes"]
|
| 481 |
+
split_audio_embeds = list(torch.split(audio_embeds, audio_chunk_sizes, dim=0))
|
| 482 |
+
split_audio_atts = list(torch.split(audio_atts, audio_chunk_sizes, dim=0))
|
| 483 |
+
|
| 484 |
+
# Wrap audio_embeds with prompts
|
| 485 |
+
audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompt)
|
| 486 |
+
|
| 487 |
+
# Prepare inputs for LLM
|
| 488 |
+
text = [t + self.end_sym for t in samples["text"]]
|
| 489 |
+
to_regress_tokens = self.llama_tokenizer(
|
| 490 |
+
text,
|
| 491 |
+
return_tensors="pt",
|
| 492 |
+
padding="longest",
|
| 493 |
+
truncation=True,
|
| 494 |
+
max_length=self.max_txt_len,
|
| 495 |
+
add_special_tokens=False,
|
| 496 |
+
).to(audio_embeds.device)
|
| 497 |
+
|
| 498 |
+
to_regress_embeds = self.llama_embed_tokens(to_regress_tokens.input_ids)
|
| 499 |
+
|
| 500 |
+
# Prepare targets
|
| 501 |
+
targets = to_regress_tokens.input_ids.masked_fill(
|
| 502 |
+
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
batch_size = audio_embeds.size(0)
|
| 506 |
+
|
| 507 |
+
# BOS token embeddings
|
| 508 |
+
bos_token_id = self.llama_tokenizer.bos_token_id
|
| 509 |
+
bos = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device)
|
| 510 |
+
bos_embeds = self.llama_embed_tokens(bos)
|
| 511 |
+
|
| 512 |
+
# Prepare lists to collect per-sample embeddings, attention masks, and targets
|
| 513 |
+
inputs_embeds_list = []
|
| 514 |
+
attention_mask_list = []
|
| 515 |
+
targets_list = []
|
| 516 |
+
|
| 517 |
+
for i in range(batch_size):
|
| 518 |
+
# Extract non-padded audio embeddings and attention mask
|
| 519 |
+
audio_embed = audio_embeds[i][audio_atts[i].bool()]
|
| 520 |
+
audio_att = audio_atts[i][audio_atts[i].bool()]
|
| 521 |
+
|
| 522 |
+
# Extract non-padded text embeddings and attention mask
|
| 523 |
+
text_embed = to_regress_embeds[i][to_regress_tokens.attention_mask[i].bool()]
|
| 524 |
+
text_att = to_regress_tokens.attention_mask[i][to_regress_tokens.attention_mask[i].bool()]
|
| 525 |
+
|
| 526 |
+
# Extract corresponding targets for the text tokens
|
| 527 |
+
target = targets[i][to_regress_tokens.attention_mask[i].bool()]
|
| 528 |
+
|
| 529 |
+
# Concatenate embeddings: BOS token, audio embeddings, text embeddings
|
| 530 |
+
input_embeds = torch.cat([bos_embeds[i], audio_embed, text_embed], dim=0)
|
| 531 |
+
|
| 532 |
+
# Concatenate attention masks: BOS token mask, audio attention mask, text attention mask
|
| 533 |
+
att_mask = torch.cat(
|
| 534 |
+
[
|
| 535 |
+
torch.ones(1, device=audio_embeds.device, dtype=audio_att.dtype),
|
| 536 |
+
audio_att,
|
| 537 |
+
text_att,
|
| 538 |
+
],
|
| 539 |
+
dim=0,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
# Create targets: Ignore index (-100) for BOS and audio tokens, actual targets for text tokens
|
| 543 |
+
ignore_targets = torch.full(
|
| 544 |
+
(1 + audio_embed.size(0),),
|
| 545 |
+
-100,
|
| 546 |
+
device=audio_embeds.device,
|
| 547 |
+
dtype=targets.dtype,
|
| 548 |
+
)
|
| 549 |
+
sample_targets = torch.cat([ignore_targets, target], dim=0)
|
| 550 |
+
|
| 551 |
+
# Append to lists
|
| 552 |
+
inputs_embeds_list.append(input_embeds)
|
| 553 |
+
attention_mask_list.append(att_mask)
|
| 554 |
+
targets_list.append(sample_targets)
|
| 555 |
+
|
| 556 |
+
# Pad sequences to the maximum length in the batch
|
| 557 |
+
inputs_embeds_padded = pad_sequence(inputs_embeds_list, batch_first=True)
|
| 558 |
+
attention_mask_padded = pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
|
| 559 |
+
targets_padded = pad_sequence(targets_list, batch_first=True, padding_value=-100)
|
| 560 |
+
|
| 561 |
+
# Now use the padded embeddings, attention masks, and targets in the model
|
| 562 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
| 563 |
+
outputs = self.llama_model(
|
| 564 |
+
inputs_embeds=inputs_embeds_padded,
|
| 565 |
+
attention_mask=attention_mask_padded,
|
| 566 |
+
return_dict=True,
|
| 567 |
+
labels=targets_padded,
|
| 568 |
+
)
|
| 569 |
+
loss = outputs.loss # Original batch loss
|
| 570 |
+
|
| 571 |
+
# Compute per-example loss
|
| 572 |
+
nvocab = self.llama_model.config.vocab_size
|
| 573 |
+
logits = outputs.logits
|
| 574 |
+
|
| 575 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 576 |
+
shift_labels = targets_padded[..., 1:].contiguous()
|
| 577 |
+
|
| 578 |
+
# Compute loss per token
|
| 579 |
+
loss_fct_per_example = CrossEntropyLoss(reduction="none")
|
| 580 |
+
loss_per_token = loss_fct_per_example(
|
| 581 |
+
shift_logits.view(-1, nvocab), # Flatten to [batch_size * (seq_len-1), vocab_size]
|
| 582 |
+
shift_labels.view(-1), # Flatten to [batch_size * (seq_len-1)]
|
| 583 |
+
)
|
| 584 |
+
loss_per_token = loss_per_token.view(shift_labels.size()) # Reshape back to [batch_size, seq_len-1]
|
| 585 |
+
|
| 586 |
+
# Create mask
|
| 587 |
+
mask = shift_labels != -100 # [batch_size, seq_len-1]
|
| 588 |
+
|
| 589 |
+
# Apply mask to loss_per_token
|
| 590 |
+
loss_per_token = loss_per_token * mask.float()
|
| 591 |
+
|
| 592 |
+
# Compute per-example loss
|
| 593 |
+
loss_per_example = loss_per_token.sum(dim=1) / mask.sum(dim=1).clamp(min=1)
|
| 594 |
+
|
| 595 |
+
if verbose:
|
| 596 |
+
# Calculate predictions
|
| 597 |
+
predicted_tokens = shift_logits.argmax(dim=-1) # [batch_size, seq_len-1]
|
| 598 |
+
|
| 599 |
+
# Compute per-example correct counts
|
| 600 |
+
correct_per_sample = ((predicted_tokens == shift_labels) & mask).sum(dim=1).float() # [batch_size]
|
| 601 |
+
total_tokens_per_sample = mask.sum(dim=1).float() # [batch_size]
|
| 602 |
+
|
| 603 |
+
# Total correct and total tokens across the batch
|
| 604 |
+
correct = correct_per_sample.sum()
|
| 605 |
+
total = total_tokens_per_sample.sum()
|
| 606 |
+
|
| 607 |
+
return {
|
| 608 |
+
"loss": loss,
|
| 609 |
+
"correct": correct,
|
| 610 |
+
"total": total,
|
| 611 |
+
"per_example_loss": loss_per_example,
|
| 612 |
+
"correct_per_sample": correct_per_sample,
|
| 613 |
+
"total_per_sample": total_tokens_per_sample,
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
return {"loss": loss, "per_example_loss": loss_per_example}
|
| 617 |
+
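# ---------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the committed file):
# the per-example loss computed in forward() above, reduced to toy shapes.
# Token-level cross entropy with ignore_index -100 is masked and averaged
# per sample; vocabulary size and sequence length are assumed demo values.
# ---------------------------------------------------------------------
import torch
from torch.nn import CrossEntropyLoss

B, S, V = 2, 6, 11
logits = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))
labels[:, :2] = -100                            # BOS + audio positions are ignored

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_per_token = CrossEntropyLoss(reduction="none")(
    shift_logits.view(-1, V), shift_labels.view(-1)
).view(shift_labels.size())

mask = shift_labels != -100
loss_per_example = (loss_per_token * mask.float()).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
print(loss_per_example.shape)                   # torch.Size([2])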
|
| 618 |
+
@torch.inference_mode()
|
| 619 |
+
def generate(self, samples, generate_cfg, prompts):
|
| 620 |
+
batch_size = len(prompts)
|
| 621 |
+
|
| 622 |
+
raw_wav = samples["raw_wav"]
|
| 623 |
+
audio_padding_mask = samples.get("padding_mask", None)
|
| 624 |
+
|
| 625 |
+
audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask=audio_padding_mask)
|
| 626 |
+
split_audio_embeds = list(torch.split(audio_embeds, samples["audio_chunk_sizes"], dim=0))
|
| 627 |
+
split_audio_atts = list(torch.split(audio_atts, samples["audio_chunk_sizes"], dim=0))
|
| 628 |
+
audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompts)
|
| 629 |
+
bos = (
|
| 630 |
+
torch.ones(
|
| 631 |
+
[batch_size, 1],
|
| 632 |
+
dtype=torch.int32,
|
| 633 |
+
device=audio_embeds.device,
|
| 634 |
+
)
|
| 635 |
+
* self.llama_tokenizer.bos_token_id
|
| 636 |
+
)
|
| 637 |
+
bos_embeds = self.llama_embed_tokens(bos)
|
| 638 |
+
atts_bos = audio_atts[:, :1]
|
| 639 |
+
|
| 640 |
+
embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
|
| 641 |
+
|
| 642 |
+
attns = torch.cat([atts_bos, audio_atts], dim=1)
|
| 643 |
+
|
| 644 |
+
stop_words_ids = [torch.tensor([2]).to(audio_embeds.device)]
|
| 645 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
| 646 |
+
|
| 647 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
| 648 |
+
outputs = self.llama_model.generate( # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
|
| 649 |
+
inputs_embeds=embeds.bfloat16(),
|
| 650 |
+
max_new_tokens=generate_cfg.max_new_tokens,
|
| 651 |
+
stopping_criteria=stopping_criteria,
|
| 652 |
+
num_beams=generate_cfg.num_beams,
|
| 653 |
+
do_sample=generate_cfg.do_sample,
|
| 654 |
+
min_length=generate_cfg.min_length,
|
| 655 |
+
temperature=generate_cfg.temperature,
|
| 656 |
+
# top_p=generate_cfg.get("top_p", 0.9),
|
| 657 |
+
repetition_penalty=generate_cfg.repetition_penalty,
|
| 658 |
+
length_penalty=generate_cfg.length_penalty,
|
| 659 |
+
attention_mask=attns.bfloat16(),
|
| 660 |
+
# prefix_allowed_tokens_fn=prefix_tokens_fn
|
| 661 |
+
# logits_processor=None
|
| 662 |
+
# constraints=[constraint] if constraint is not None else None
|
| 663 |
+
)
|
| 664 |
+
text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 665 |
+
|
| 666 |
+
return text
|
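# ---------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of the committed
# file): a hypothetical call into NatureLM.generate() based only on the
# fields the method reads above. The 16 kHz sample rate, the prompt text,
# and the model-loading step are assumptions, not the project's API.
# ---------------------------------------------------------------------
import torch
from types import SimpleNamespace

generate_cfg = SimpleNamespace(
    max_new_tokens=128, num_beams=2, do_sample=False, min_length=1,
    temperature=1.0, repetition_penalty=1.0, length_penalty=1.0,
)
samples = {
    "raw_wav": torch.zeros(1, 16000 * 10),                        # one 10 s clip
    "padding_mask": torch.zeros(1, 16000 * 10, dtype=torch.bool),
    "audio_chunk_sizes": [1],                                     # one audio chunk for the single prompt
}
prompts = ["<AudioHere> Which species is vocalizing in this recording?"]
# model = ...                                   # a loaded NatureLM instance
# answer = model.generate(samples, generate_cfg, prompts)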
NatureLM/models/Qformer.py
ADDED
|
@@ -0,0 +1,1091 @@
| 1 |
+
"""
|
| 2 |
+
Adapted from salesforce@LAVIS. Below is the original copyright:
|
| 3 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 6 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 7 |
+
* By Junnan Li
|
| 8 |
+
* Based on huggingface code base
|
| 9 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.utils.checkpoint
|
| 17 |
+
from torch import Tensor, device, nn
|
| 18 |
+
from torch.nn import CrossEntropyLoss
|
| 19 |
+
from transformers.activations import ACT2FN
|
| 20 |
+
from transformers.modeling_outputs import (
|
| 21 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 22 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 23 |
+
CausalLMOutputWithCrossAttentions,
|
| 24 |
+
MaskedLMOutput,
|
| 25 |
+
)
|
| 26 |
+
from transformers.modeling_utils import (
|
| 27 |
+
PreTrainedModel,
|
| 28 |
+
apply_chunking_to_forward,
|
| 29 |
+
find_pruneable_heads_and_indices,
|
| 30 |
+
prune_linear_layer,
|
| 31 |
+
)
|
| 32 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 33 |
+
from transformers.utils import logging
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BertEmbeddings(nn.Module):
|
| 39 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 40 |
+
|
| 41 |
+
def __init__(self, config):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 44 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 45 |
+
|
| 46 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 47 |
+
# any TensorFlow checkpoint file
|
| 48 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 49 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 50 |
+
|
| 51 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 52 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 53 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 54 |
+
|
| 55 |
+
self.config = config
|
| 56 |
+
|
| 57 |
+
def forward(
|
| 58 |
+
self,
|
| 59 |
+
input_ids=None,
|
| 60 |
+
position_ids=None,
|
| 61 |
+
query_embeds=None,
|
| 62 |
+
past_key_values_length=0,
|
| 63 |
+
):
|
| 64 |
+
if input_ids is not None:
|
| 65 |
+
seq_length = input_ids.size()[1]
|
| 66 |
+
else:
|
| 67 |
+
seq_length = 0
|
| 68 |
+
|
| 69 |
+
if position_ids is None:
|
| 70 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
| 71 |
+
|
| 72 |
+
if input_ids is not None:
|
| 73 |
+
embeddings = self.word_embeddings(input_ids)
|
| 74 |
+
if self.position_embedding_type == "absolute":
|
| 75 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 76 |
+
embeddings = embeddings + position_embeddings
|
| 77 |
+
|
| 78 |
+
if query_embeds is not None:
|
| 79 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
| 80 |
+
else:
|
| 81 |
+
embeddings = query_embeds
|
| 82 |
+
|
| 83 |
+
embeddings = self.LayerNorm(embeddings)
|
| 84 |
+
embeddings = self.dropout(embeddings)
|
| 85 |
+
return embeddings
|
| 86 |
+
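# ---------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the committed file):
# BertEmbeddings.forward above prepends the learned query embeddings to the
# token embeddings when both are given; with no input_ids the queries are
# used alone. Shapes below are assumed demo values.
# ---------------------------------------------------------------------
import torch

query_embeds = torch.randn(2, 32, 768)          # (batch, num_query_tokens, hidden)
word_embeds = torch.randn(2, 5, 768)            # (batch, text_len, hidden)
combined = torch.cat((query_embeds, word_embeds), dim=1)
print(combined.shape)                           # torch.Size([2, 37, 768])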
|
| 87 |
+
|
| 88 |
+
class BertSelfAttention(nn.Module):
|
| 89 |
+
def __init__(self, config, is_cross_attention):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.config = config
|
| 92 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 93 |
+
raise ValueError(
|
| 94 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 95 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.num_attention_heads = config.num_attention_heads
|
| 99 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 100 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 101 |
+
|
| 102 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 103 |
+
if is_cross_attention:
|
| 104 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 105 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 106 |
+
else:
|
| 107 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 108 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 109 |
+
|
| 110 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 111 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 112 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 113 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 114 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 115 |
+
self.save_attention = False
|
| 116 |
+
|
| 117 |
+
def save_attn_gradients(self, attn_gradients):
|
| 118 |
+
self.attn_gradients = attn_gradients
|
| 119 |
+
|
| 120 |
+
def get_attn_gradients(self):
|
| 121 |
+
return self.attn_gradients
|
| 122 |
+
|
| 123 |
+
def save_attention_map(self, attention_map):
|
| 124 |
+
self.attention_map = attention_map
|
| 125 |
+
|
| 126 |
+
def get_attention_map(self):
|
| 127 |
+
return self.attention_map
|
| 128 |
+
|
| 129 |
+
def transpose_for_scores(self, x):
|
| 130 |
+
new_x_shape = x.size()[:-1] + (
|
| 131 |
+
self.num_attention_heads,
|
| 132 |
+
self.attention_head_size,
|
| 133 |
+
)
|
| 134 |
+
x = x.view(*new_x_shape)
|
| 135 |
+
return x.permute(0, 2, 1, 3)
|
| 136 |
+
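# ---------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the committed file):
# transpose_for_scores above reshapes (batch, seq, hidden) into
# (batch, heads, seq, head_dim) so attention scores are computed per head.
# ---------------------------------------------------------------------
import torch

batch, seq, heads, head_dim = 2, 7, 12, 64
x = torch.randn(batch, seq, heads * head_dim)
per_head = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)
print(per_head.shape)                           # torch.Size([2, 12, 7, 64])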
|
| 137 |
+
def forward(
|
| 138 |
+
self,
|
| 139 |
+
hidden_states,
|
| 140 |
+
attention_mask=None,
|
| 141 |
+
head_mask=None,
|
| 142 |
+
encoder_hidden_states=None,
|
| 143 |
+
encoder_attention_mask=None,
|
| 144 |
+
past_key_value=None,
|
| 145 |
+
output_attentions=False,
|
| 146 |
+
):
|
| 147 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 148 |
+
# and values come from an encoder; the attention mask needs to be
|
| 149 |
+
# such that the encoder's padding tokens are not attended to.
|
| 150 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 151 |
+
|
| 152 |
+
if is_cross_attention:
|
| 153 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 154 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 155 |
+
attention_mask = encoder_attention_mask
|
| 156 |
+
elif past_key_value is not None:
|
| 157 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 158 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 159 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 160 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 161 |
+
else:
|
| 162 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 163 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 164 |
+
|
| 165 |
+
mixed_query_layer = self.query(hidden_states)
|
| 166 |
+
|
| 167 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 168 |
+
|
| 169 |
+
past_key_value = (key_layer, value_layer)
|
| 170 |
+
|
| 171 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 172 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 173 |
+
|
| 174 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 175 |
+
seq_length = hidden_states.size()[1]
|
| 176 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 177 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 178 |
+
distance = position_ids_l - position_ids_r
|
| 179 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 180 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 181 |
+
|
| 182 |
+
if self.position_embedding_type == "relative_key":
|
| 183 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 184 |
+
attention_scores = attention_scores + relative_position_scores
|
| 185 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 186 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 187 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 188 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 189 |
+
|
| 190 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 191 |
+
if attention_mask is not None:
|
| 192 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 193 |
+
attention_scores = attention_scores + attention_mask
|
| 194 |
+
|
| 195 |
+
# Normalize the attention scores to probabilities.
|
| 196 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 197 |
+
|
| 198 |
+
if is_cross_attention and self.save_attention:
|
| 199 |
+
self.save_attention_map(attention_probs)
|
| 200 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 201 |
+
|
| 202 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 203 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 204 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 205 |
+
|
| 206 |
+
# Mask heads if we want to
|
| 207 |
+
if head_mask is not None:
|
| 208 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
| 209 |
+
|
| 210 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 211 |
+
|
| 212 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 213 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 214 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 215 |
+
|
| 216 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 217 |
+
|
| 218 |
+
outputs = outputs + (past_key_value,)
|
| 219 |
+
return outputs
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class BertSelfOutput(nn.Module):
|
| 223 |
+
def __init__(self, config):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 226 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 227 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 228 |
+
|
| 229 |
+
def forward(self, hidden_states, input_tensor):
|
| 230 |
+
hidden_states = self.dense(hidden_states)
|
| 231 |
+
hidden_states = self.dropout(hidden_states)
|
| 232 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 233 |
+
return hidden_states
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class BertAttention(nn.Module):
|
| 237 |
+
def __init__(self, config, is_cross_attention=False):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
| 240 |
+
self.output = BertSelfOutput(config)
|
| 241 |
+
self.pruned_heads = set()
|
| 242 |
+
|
| 243 |
+
def prune_heads(self, heads):
|
| 244 |
+
if len(heads) == 0:
|
| 245 |
+
return
|
| 246 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 247 |
+
heads,
|
| 248 |
+
self.self.num_attention_heads,
|
| 249 |
+
self.self.attention_head_size,
|
| 250 |
+
self.pruned_heads,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Prune linear layers
|
| 254 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 255 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 256 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 257 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 258 |
+
|
| 259 |
+
# Update hyper params and store pruned heads
|
| 260 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 261 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 262 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 263 |
+
|
| 264 |
+
def forward(
|
| 265 |
+
self,
|
| 266 |
+
hidden_states,
|
| 267 |
+
attention_mask=None,
|
| 268 |
+
head_mask=None,
|
| 269 |
+
encoder_hidden_states=None,
|
| 270 |
+
encoder_attention_mask=None,
|
| 271 |
+
past_key_value=None,
|
| 272 |
+
output_attentions=False,
|
| 273 |
+
):
|
| 274 |
+
self_outputs = self.self(
|
| 275 |
+
hidden_states,
|
| 276 |
+
attention_mask,
|
| 277 |
+
head_mask,
|
| 278 |
+
encoder_hidden_states,
|
| 279 |
+
encoder_attention_mask,
|
| 280 |
+
past_key_value,
|
| 281 |
+
output_attentions,
|
| 282 |
+
)
|
| 283 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 284 |
+
|
| 285 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 286 |
+
return outputs
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class BertIntermediate(nn.Module):
|
| 290 |
+
def __init__(self, config):
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 293 |
+
if isinstance(config.hidden_act, str):
|
| 294 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 295 |
+
else:
|
| 296 |
+
self.intermediate_act_fn = config.hidden_act
|
| 297 |
+
|
| 298 |
+
def forward(self, hidden_states):
|
| 299 |
+
hidden_states = self.dense(hidden_states)
|
| 300 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 301 |
+
return hidden_states
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class BertOutput(nn.Module):
|
| 305 |
+
def __init__(self, config):
|
| 306 |
+
super().__init__()
|
| 307 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 308 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 309 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 310 |
+
|
| 311 |
+
def forward(self, hidden_states, input_tensor):
|
| 312 |
+
hidden_states = self.dense(hidden_states)
|
| 313 |
+
hidden_states = self.dropout(hidden_states)
|
| 314 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 315 |
+
return hidden_states
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class BertLayer(nn.Module):
|
| 319 |
+
def __init__(self, config, layer_num):
|
| 320 |
+
super().__init__()
|
| 321 |
+
self.config = config
|
| 322 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 323 |
+
self.seq_len_dim = 1
|
| 324 |
+
self.attention = BertAttention(config)
|
| 325 |
+
self.layer_num = layer_num
|
| 326 |
+
if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
|
| 327 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
| 328 |
+
self.has_cross_attention = True
|
| 329 |
+
else:
|
| 330 |
+
self.has_cross_attention = False
|
| 331 |
+
self.intermediate = BertIntermediate(config)
|
| 332 |
+
self.output = BertOutput(config)
|
| 333 |
+
|
| 334 |
+
self.intermediate_query = BertIntermediate(config)
|
| 335 |
+
self.output_query = BertOutput(config)
|
| 336 |
+
|
| 337 |
+
def forward(
|
| 338 |
+
self,
|
| 339 |
+
hidden_states,
|
| 340 |
+
attention_mask=None,
|
| 341 |
+
head_mask=None,
|
| 342 |
+
encoder_hidden_states=None,
|
| 343 |
+
encoder_attention_mask=None,
|
| 344 |
+
past_key_value=None,
|
| 345 |
+
output_attentions=False,
|
| 346 |
+
query_length=0,
|
| 347 |
+
):
|
| 348 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 349 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 350 |
+
self_attention_outputs = self.attention(
|
| 351 |
+
hidden_states,
|
| 352 |
+
attention_mask,
|
| 353 |
+
head_mask,
|
| 354 |
+
output_attentions=output_attentions,
|
| 355 |
+
past_key_value=self_attn_past_key_value,
|
| 356 |
+
)
|
| 357 |
+
attention_output = self_attention_outputs[0]
|
| 358 |
+
outputs = self_attention_outputs[1:-1]
|
| 359 |
+
|
| 360 |
+
present_key_value = self_attention_outputs[-1]
|
| 361 |
+
|
| 362 |
+
if query_length > 0:
|
| 363 |
+
query_attention_output = attention_output[:, :query_length, :]
|
| 364 |
+
|
| 365 |
+
if self.has_cross_attention:
|
| 366 |
+
assert (
|
| 367 |
+
encoder_hidden_states is not None
|
| 368 |
+
), "encoder_hidden_states must be given for cross-attention layers"
|
| 369 |
+
cross_attention_outputs = self.crossattention(
|
| 370 |
+
query_attention_output,
|
| 371 |
+
attention_mask,
|
| 372 |
+
head_mask,
|
| 373 |
+
encoder_hidden_states,
|
| 374 |
+
encoder_attention_mask,
|
| 375 |
+
output_attentions=output_attentions,
|
| 376 |
+
)
|
| 377 |
+
query_attention_output = cross_attention_outputs[0]
|
| 378 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 379 |
+
|
| 380 |
+
layer_output = apply_chunking_to_forward(
|
| 381 |
+
self.feed_forward_chunk_query,
|
| 382 |
+
self.chunk_size_feed_forward,
|
| 383 |
+
self.seq_len_dim,
|
| 384 |
+
query_attention_output,
|
| 385 |
+
)
|
| 386 |
+
if attention_output.shape[1] > query_length:
|
| 387 |
+
layer_output_text = apply_chunking_to_forward(
|
| 388 |
+
self.feed_forward_chunk,
|
| 389 |
+
self.chunk_size_feed_forward,
|
| 390 |
+
self.seq_len_dim,
|
| 391 |
+
attention_output[:, query_length:, :],
|
| 392 |
+
)
|
| 393 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
| 394 |
+
else:
|
| 395 |
+
layer_output = apply_chunking_to_forward(
|
| 396 |
+
self.feed_forward_chunk,
|
| 397 |
+
self.chunk_size_feed_forward,
|
| 398 |
+
self.seq_len_dim,
|
| 399 |
+
attention_output,
|
| 400 |
+
)
|
| 401 |
+
outputs = (layer_output,) + outputs
|
| 402 |
+
|
| 403 |
+
outputs = outputs + (present_key_value,)
|
| 404 |
+
|
| 405 |
+
return outputs
|
| 406 |
+
|
| 407 |
+
def feed_forward_chunk(self, attention_output):
|
| 408 |
+
intermediate_output = self.intermediate(attention_output)
|
| 409 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 410 |
+
return layer_output
|
| 411 |
+
|
| 412 |
+
def feed_forward_chunk_query(self, attention_output):
|
| 413 |
+
intermediate_output = self.intermediate_query(attention_output)
|
| 414 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
| 415 |
+
return layer_output
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class BertEncoder(nn.Module):
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.config = config
|
| 422 |
+
self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
|
| 423 |
+
|
| 424 |
+
def forward(
|
| 425 |
+
self,
|
| 426 |
+
hidden_states,
|
| 427 |
+
attention_mask=None,
|
| 428 |
+
head_mask=None,
|
| 429 |
+
encoder_hidden_states=None,
|
| 430 |
+
encoder_attention_mask=None,
|
| 431 |
+
past_key_values=None,
|
| 432 |
+
use_cache=None,
|
| 433 |
+
output_attentions=False,
|
| 434 |
+
output_hidden_states=False,
|
| 435 |
+
return_dict=True,
|
| 436 |
+
query_length=0,
|
| 437 |
+
):
|
| 438 |
+
all_hidden_states = () if output_hidden_states else None
|
| 439 |
+
all_self_attentions = () if output_attentions else None
|
| 440 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 441 |
+
|
| 442 |
+
next_decoder_cache = () if use_cache else None
|
| 443 |
+
|
| 444 |
+
for i in range(self.config.num_hidden_layers):
|
| 445 |
+
layer_module = self.layer[i]
|
| 446 |
+
if output_hidden_states:
|
| 447 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 448 |
+
|
| 449 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 450 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 451 |
+
|
| 452 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
| 453 |
+
if use_cache:
|
| 454 |
+
logger.warn(
|
| 455 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 456 |
+
)
|
| 457 |
+
use_cache = False
|
| 458 |
+
|
| 459 |
+
def create_custom_forward(module):
|
| 460 |
+
def custom_forward(*inputs):
|
| 461 |
+
return module(*inputs, past_key_value, output_attentions, query_length)
|
| 462 |
+
|
| 463 |
+
return custom_forward
|
| 464 |
+
|
| 465 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 466 |
+
create_custom_forward(layer_module),
|
| 467 |
+
hidden_states,
|
| 468 |
+
attention_mask,
|
| 469 |
+
layer_head_mask,
|
| 470 |
+
encoder_hidden_states,
|
| 471 |
+
encoder_attention_mask,
|
| 472 |
+
)
|
| 473 |
+
else:
|
| 474 |
+
layer_outputs = layer_module(
|
| 475 |
+
hidden_states,
|
| 476 |
+
attention_mask,
|
| 477 |
+
layer_head_mask,
|
| 478 |
+
encoder_hidden_states,
|
| 479 |
+
encoder_attention_mask,
|
| 480 |
+
past_key_value,
|
| 481 |
+
output_attentions,
|
| 482 |
+
query_length,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
hidden_states = layer_outputs[0]
|
| 486 |
+
if use_cache:
|
| 487 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 488 |
+
if output_attentions:
|
| 489 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 490 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
| 491 |
+
|
| 492 |
+
if output_hidden_states:
|
| 493 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 494 |
+
|
| 495 |
+
if not return_dict:
|
| 496 |
+
return tuple(
|
| 497 |
+
v
|
| 498 |
+
for v in [
|
| 499 |
+
hidden_states,
|
| 500 |
+
next_decoder_cache,
|
| 501 |
+
all_hidden_states,
|
| 502 |
+
all_self_attentions,
|
| 503 |
+
all_cross_attentions,
|
| 504 |
+
]
|
| 505 |
+
if v is not None
|
| 506 |
+
)
|
| 507 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 508 |
+
last_hidden_state=hidden_states,
|
| 509 |
+
past_key_values=next_decoder_cache,
|
| 510 |
+
hidden_states=all_hidden_states,
|
| 511 |
+
attentions=all_self_attentions,
|
| 512 |
+
cross_attentions=all_cross_attentions,
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class BertPooler(nn.Module):
|
| 517 |
+
def __init__(self, config):
|
| 518 |
+
super().__init__()
|
| 519 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 520 |
+
self.activation = nn.Tanh()
|
| 521 |
+
|
| 522 |
+
def forward(self, hidden_states):
|
| 523 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 524 |
+
# to the first token.
|
| 525 |
+
first_token_tensor = hidden_states[:, 0]
|
| 526 |
+
pooled_output = self.dense(first_token_tensor)
|
| 527 |
+
pooled_output = self.activation(pooled_output)
|
| 528 |
+
return pooled_output
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 532 |
+
def __init__(self, config):
|
| 533 |
+
super().__init__()
|
| 534 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 535 |
+
if isinstance(config.hidden_act, str):
|
| 536 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 537 |
+
else:
|
| 538 |
+
self.transform_act_fn = config.hidden_act
|
| 539 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 540 |
+
|
| 541 |
+
def forward(self, hidden_states):
|
| 542 |
+
hidden_states = self.dense(hidden_states)
|
| 543 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 544 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 545 |
+
return hidden_states
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class BertLMPredictionHead(nn.Module):
|
| 549 |
+
def __init__(self, config):
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 552 |
+
|
| 553 |
+
# The output weights are the same as the input embeddings, but there is
|
| 554 |
+
# an output-only bias for each token.
|
| 555 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 556 |
+
|
| 557 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 558 |
+
|
| 559 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 560 |
+
self.decoder.bias = self.bias
|
| 561 |
+
|
| 562 |
+
def forward(self, hidden_states):
|
| 563 |
+
hidden_states = self.transform(hidden_states)
|
| 564 |
+
hidden_states = self.decoder(hidden_states)
|
| 565 |
+
return hidden_states
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
class BertOnlyMLMHead(nn.Module):
|
| 569 |
+
def __init__(self, config):
|
| 570 |
+
super().__init__()
|
| 571 |
+
self.predictions = BertLMPredictionHead(config)
|
| 572 |
+
|
| 573 |
+
def forward(self, sequence_output):
|
| 574 |
+
prediction_scores = self.predictions(sequence_output)
|
| 575 |
+
return prediction_scores
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 579 |
+
"""
|
| 580 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 581 |
+
models.
|
| 582 |
+
"""
|
| 583 |
+
|
| 584 |
+
config_class = BertConfig
|
| 585 |
+
base_model_prefix = "bert"
|
| 586 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 587 |
+
|
| 588 |
+
def _init_weights(self, module):
|
| 589 |
+
"""Initialize the weights"""
|
| 590 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 591 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 592 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 593 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 594 |
+
elif isinstance(module, nn.LayerNorm):
|
| 595 |
+
module.bias.data.zero_()
|
| 596 |
+
module.weight.data.fill_(1.0)
|
| 597 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 598 |
+
module.bias.data.zero_()
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class BertModel(BertPreTrainedModel):
|
| 602 |
+
"""
|
| 603 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 604 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
| 605 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 606 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 607 |
+
To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
| 608 |
+
input to the forward pass.
|
| 609 |
+
"""
|
| 610 |
+
|
| 611 |
+
def __init__(self, config, add_pooling_layer=False):
|
| 612 |
+
super().__init__(config)
|
| 613 |
+
self.config = config
|
| 614 |
+
|
| 615 |
+
self.embeddings = BertEmbeddings(config)
|
| 616 |
+
|
| 617 |
+
self.encoder = BertEncoder(config)
|
| 618 |
+
|
| 619 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 620 |
+
|
| 621 |
+
self.init_weights()
|
| 622 |
+
|
| 623 |
+
def get_input_embeddings(self):
|
| 624 |
+
return self.embeddings.word_embeddings
|
| 625 |
+
|
| 626 |
+
def set_input_embeddings(self, value):
|
| 627 |
+
self.embeddings.word_embeddings = value
|
| 628 |
+
|
| 629 |
+
def _prune_heads(self, heads_to_prune):
|
| 630 |
+
"""
|
| 631 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 632 |
+
class PreTrainedModel
|
| 633 |
+
"""
|
| 634 |
+
for layer, heads in heads_to_prune.items():
|
| 635 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 636 |
+
|
| 637 |
+
def get_extended_attention_mask(
|
| 638 |
+
self,
|
| 639 |
+
attention_mask: Tensor,
|
| 640 |
+
input_shape: Tuple[int],
|
| 641 |
+
device: device,
|
| 642 |
+
is_decoder: bool,
|
| 643 |
+
has_query: bool = False,
|
| 644 |
+
) -> Tensor:
|
| 645 |
+
"""
|
| 646 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
| 647 |
+
|
| 648 |
+
Arguments:
|
| 649 |
+
attention_mask (:obj:`torch.Tensor`):
|
| 650 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
| 651 |
+
input_shape (:obj:`Tuple[int]`):
|
| 652 |
+
The shape of the input to the model.
|
| 653 |
+
device: (:obj:`torch.device`):
|
| 654 |
+
The device of the input to the model.
|
| 655 |
+
|
| 656 |
+
Returns:
|
| 657 |
+
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
|
| 658 |
+
"""
|
| 659 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 660 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 661 |
+
if attention_mask.dim() == 3:
|
| 662 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 663 |
+
elif attention_mask.dim() == 2:
|
| 664 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 665 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 666 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 667 |
+
if is_decoder:
|
| 668 |
+
batch_size, seq_length = input_shape
|
| 669 |
+
|
| 670 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 671 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 672 |
+
|
| 673 |
+
# add a prefix ones mask to the causal mask
|
| 674 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 675 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 676 |
+
|
| 677 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
| 678 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
| 679 |
+
if has_query: # UniLM style attention mask
|
| 680 |
+
causal_mask = torch.cat(
|
| 681 |
+
[
|
| 682 |
+
torch.zeros(
|
| 683 |
+
(batch_size, prefix_seq_len, seq_length),
|
| 684 |
+
device=device,
|
| 685 |
+
dtype=causal_mask.dtype,
|
| 686 |
+
),
|
| 687 |
+
causal_mask,
|
| 688 |
+
],
|
| 689 |
+
axis=1,
|
| 690 |
+
)
|
| 691 |
+
causal_mask = torch.cat(
|
| 692 |
+
[
|
| 693 |
+
torch.ones(
|
| 694 |
+
(batch_size, causal_mask.shape[1], prefix_seq_len),
|
| 695 |
+
device=device,
|
| 696 |
+
dtype=causal_mask.dtype,
|
| 697 |
+
),
|
| 698 |
+
causal_mask,
|
| 699 |
+
],
|
| 700 |
+
axis=-1,
|
| 701 |
+
)
|
| 702 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 703 |
+
else:
|
| 704 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 705 |
+
else:
|
| 706 |
+
raise ValueError(
|
| 707 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 708 |
+
input_shape, attention_mask.shape
|
| 709 |
+
)
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 713 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 714 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 715 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 716 |
+
# effectively the same as removing these entirely.
|
| 717 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 718 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 719 |
+
return extended_attention_mask
|
| 720 |
+
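# ---------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the committed file):
# the additive-mask trick used above. A 0/1 padding mask becomes a bias of
# 0 for visible positions and -10000 for padded ones, which is added to the
# attention scores before the softmax.
# ---------------------------------------------------------------------
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])             # (batch, seq_len)
extended = attention_mask[:, None, None, :].float()          # (batch, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0
print(extended)                                 # 0 for visible tokens, -10000.0 for padding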
|
| 721 |
+
def forward(
|
| 722 |
+
self,
|
| 723 |
+
input_ids=None,
|
| 724 |
+
attention_mask=None,
|
| 725 |
+
position_ids=None,
|
| 726 |
+
head_mask=None,
|
| 727 |
+
query_embeds=None,
|
| 728 |
+
encoder_hidden_states=None,
|
| 729 |
+
encoder_attention_mask=None,
|
| 730 |
+
past_key_values=None,
|
| 731 |
+
use_cache=None,
|
| 732 |
+
output_attentions=None,
|
| 733 |
+
output_hidden_states=None,
|
| 734 |
+
return_dict=None,
|
| 735 |
+
is_decoder=False,
|
| 736 |
+
):
|
| 737 |
+
r"""
|
| 738 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 739 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 740 |
+
the model is configured as a decoder.
|
| 741 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 742 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 743 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 744 |
+
- 1 for tokens that are **not masked**,
|
| 745 |
+
- 0 for tokens that are **masked**.
|
| 746 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 747 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 748 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 749 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 750 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 751 |
+
use_cache (:obj:`bool`, `optional`):
|
| 752 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 753 |
+
decoding (see :obj:`past_key_values`).
|
| 754 |
+
"""
|
| 755 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 756 |
+
output_hidden_states = (
|
| 757 |
+
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past


class BertForMaskedLM(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
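Note: the language-modeling loss above follows standard next-token prediction: logits are shifted left by one position relative to the labels before the cross-entropy. A minimal, self-contained sketch of that shift (illustrative only; the toy vocabulary size and random tensors below are made up, not taken from this repo):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11                               # toy vocabulary, for illustration only
logits = torch.randn(2, 5, vocab_size)        # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))

# predict token t+1 from position t: drop the last logit column and the first label
shifted_logits = logits[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

# label_smoothing=0.1 mirrors the setting used in BertLMHeadModel above
loss_fct = CrossEntropyLoss(reduction="mean", label_smoothing=0.1)
loss = loss_fct(shifted_logits.view(-1, vocab_size), shifted_labels.view(-1))
print(loss.item())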
NatureLM/models/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .NatureLM import NatureLM


def load_model(config):
    return NatureLM.from_config(config)
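A hedged usage sketch of the package entry point defined above. The config object here is hypothetical (whatever NatureLM.from_config accepts, e.g. a parsed configs/inference.yml), so treat it as illustrative pseudo-usage rather than a guaranteed API:

# Illustrative only: assumes `cfg` is a config object accepted by NatureLM.from_config
# (for example, one built from configs/inference.yml elsewhere in this repo).
from NatureLM.models import load_model

def build(cfg):
    model = load_model(cfg)   # thin wrapper around NatureLM.from_config(cfg)
    model.eval()
    return model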
NatureLM/models/__pycache__/NatureLM.cpython-310.pyc
ADDED
Binary file (15.3 kB)
NatureLM/models/__pycache__/Qformer.cpython-310.pyc
ADDED
Binary file (30 kB)
NatureLM/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (329 Bytes)
NatureLM/models/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (926 Bytes)
NatureLM/models/aves.py
ADDED
@@ -0,0 +1,59 @@
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.models import wav2vec2_model


class AvesEmbedding(nn.Module):
    def __init__(self, sr, large=False):
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
        if large:
            config = self.load_config("configs/birdaves_bioxlarge.config")
        else:
            config = self.load_config("configs/birdaves_bioxbase.config")
        self.model = wav2vec2_model(**config, aux_num_out=None)
        state_dict = torch.hub.load_state_dict_from_url(
            "https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.pt",
            map_location=device,
        )
        self.model.load_state_dict(state_dict)
        self.model.feature_extractor.requires_grad_(True)

        # bundle = torchaudio.pipelines.WAV2VEC2_BASE
        # self.model = bundle.get_model()

        self.sr = sr

    def load_config(self, config_path):
        with open(config_path, "r") as ff:
            obj = json.load(ff)

        return obj

    def forward(self, sig, padding_mask):
        # extract_feature in the torchaudio version will output all 12 layers' output, -1 to select the final one
        # print("sig", sig)

        out = self.model.extract_features(sig.float())[0][-1]
        atts = ~padding_mask
        atts = atts.unsqueeze(1).float()
        atts = F.max_pool1d(atts, kernel_size=320, stride=320)
        atts = atts > 0
        padding_mask = ~atts

        return out, padding_mask

    def freeze(self):
        for param in self.model.encoder.parameters():
            param.requires_grad = False
        self.model.feature_extractor.requires_grad_(False)

    def unfreeze(self):
        for param in self.model.encoder.parameters():
            param.requires_grad = True
        self.model.feature_extractor.requires_grad_(True)
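A minimal usage sketch for the embedding module above, assuming the BirdAVES config file exists locally and the pretrained checkpoint can be downloaded by torch.hub; the sample-level padding mask (True on padded samples) is pooled with the same 320-sample hop as the wav2vec2 feature extractor:

# Illustrative sketch; requires configs/birdaves_bioxbase.config and network access
# for torch.hub to fetch the BirdAVES checkpoint referenced in AvesEmbedding.__init__.
import torch
from NatureLM.models.aves import AvesEmbedding

model = AvesEmbedding(sr=16000)
model.freeze()  # freeze encoder and feature extractor for inference

wav = torch.randn(1, 16000 * 10)                         # 10 s of 16 kHz audio
padding_mask = torch.zeros_like(wav, dtype=torch.bool)   # True marks padded samples

with torch.no_grad():
    feats, frame_mask = model(wav, padding_mask)
# feats: (batch, ~frames, hidden); frame_mask is the padding mask downsampled by 320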
NatureLM/models/beats/BEATs.py
ADDED
@@ -0,0 +1,181 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import logging
from typing import Optional

import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as ta_kaldi
from torch.nn import LayerNorm

from .backbone import TransformerEncoder

logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)

    def to_dict(self):
        return self.__dict__


class BEATs(nn.Module):
    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2**15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only=False,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if not feature_only and self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask

    def forward(self, source: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        return self.extract_features(source, padding_mask, feature_only=True)
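A sketch of running the encoder above as a pure feature extractor on raw 16 kHz audio. The patch size of 16 is an assumption matching the released BEATs checkpoints, and no pretrained weights are loaded here, so the outputs are only meaningful as a shape check:

# Illustrative only: randomly initialised BEATs encoder used as a feature extractor.
import torch
from NatureLM.models.beats.BEATs import BEATs, BEATsConfig

cfg = BEATsConfig({"input_patch_size": 16})  # 16x16 patches, as in released checkpoints (assumption)
model = BEATs(cfg).eval()

audio = torch.randn(2, 16000 * 5)                          # two 5 s clips at 16 kHz
padding_mask = torch.zeros_like(audio, dtype=torch.bool)   # True where samples are padding

with torch.no_grad():
    feats, out_mask = model.extract_features(audio, padding_mask, feature_only=True)
# feats: (batch, num_patches, encoder_embed_dim); out_mask follows the patch grid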
NatureLM/models/beats/Tokenizers.py
ADDED
@@ -0,0 +1,173 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------


import logging
from typing import Optional

import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as ta_kaldi
from torch.nn import LayerNorm

from .backbone import (
    TransformerEncoder,
)
from .quantizer import (
    NormEMAVectorQuantizer,
)

logger = logging.getLogger(__name__)


class TokenizersConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # quantizer
        self.quant_n: int = 1024  # codebook number in quantizer
        self.quant_dim: int = 256  # codebook dimension in quantizer

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class Tokenizers(nn.Module):
    def __init__(
        self,
        cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, bias=cfg.conv_bias
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n,
            embedding_dim=cfg.quant_dim,
            beta=1.0,
            kmeans_init=True,
            decay=0.99,
        )
        self.quant_n = cfg.quant_n
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim),  # for quantize
        )

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2**15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        quantize_input = self.quantize_layer(x)
        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

        return embed_ind
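The tokenizer above maps each audio patch to a discrete codebook index via the EMA vector quantizer. A hedged sketch of calling it (randomly initialised, no pretrained tokenizer checkpoint loaded, patch size again assumed to be 16), useful mainly as a shape and interface check:

# Illustrative pseudo-usage: real use would load a pretrained acoustic tokenizer
# checkpoint into Tokenizers before calling extract_labels.
import torch
from NatureLM.models.beats.Tokenizers import Tokenizers, TokenizersConfig

cfg = TokenizersConfig({"input_patch_size": 16})  # patch size assumed, as above
tokenizer = Tokenizers(cfg).eval()

audio = torch.randn(1, 16000 * 2)  # 2 s at 16 kHz
with torch.no_grad():
    codes = tokenizer.extract_labels(audio)
# codes: discrete codebook indices (one id per patch) from NormEMAVectorQuantizer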
NatureLM/models/beats/__init__.py
ADDED
File without changes
NatureLM/models/beats/__pycache__/BEATs.cpython-310.pyc
ADDED
Binary file (4.48 kB)
NatureLM/models/beats/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (173 Bytes)
NatureLM/models/beats/__pycache__/backbone.cpython-310.pyc
ADDED
Binary file (16.6 kB)
NatureLM/models/beats/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (6.14 kB)
NatureLM/models/beats/backbone.py
ADDED
@@ -0,0 +1,741 @@
# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import LayerNorm, Parameter

from .modules import (
    GLU_Linear,
    GradMultiply,
    SamePad,
    get_activation_fn,
    quant_noise,
)


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for i in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
                if tgt_layer is not None:
                    layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None,
    ):
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)

        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 527 |
+
if "prev_key" in saved_state:
|
| 528 |
+
_prev_key = saved_state["prev_key"]
|
| 529 |
+
assert _prev_key is not None
|
| 530 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 531 |
+
if static_kv:
|
| 532 |
+
k = prev_key
|
| 533 |
+
else:
|
| 534 |
+
assert k is not None
|
| 535 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 536 |
+
src_len = k.size(1)
|
| 537 |
+
if "prev_value" in saved_state:
|
| 538 |
+
_prev_value = saved_state["prev_value"]
|
| 539 |
+
assert _prev_value is not None
|
| 540 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 541 |
+
if static_kv:
|
| 542 |
+
v = prev_value
|
| 543 |
+
else:
|
| 544 |
+
assert v is not None
|
| 545 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 546 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 547 |
+
if "prev_key_padding_mask" in saved_state:
|
| 548 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 549 |
+
assert k is not None and v is not None
|
| 550 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 551 |
+
key_padding_mask=key_padding_mask,
|
| 552 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 553 |
+
batch_size=bsz,
|
| 554 |
+
src_len=k.size(1),
|
| 555 |
+
static_kv=static_kv,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 559 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 560 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 561 |
+
# In this branch incremental_state is never None
|
| 562 |
+
assert incremental_state is not None
|
| 563 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 564 |
+
assert k is not None
|
| 565 |
+
assert k.size(1) == src_len
|
| 566 |
+
|
| 567 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 568 |
+
# not supporting Optional types.
|
| 569 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 570 |
+
key_padding_mask = None
|
| 571 |
+
|
| 572 |
+
if key_padding_mask is not None:
|
| 573 |
+
assert key_padding_mask.size(0) == bsz
|
| 574 |
+
assert key_padding_mask.size(1) == src_len
|
| 575 |
+
|
| 576 |
+
if self.add_zero_attn:
|
| 577 |
+
assert v is not None
|
| 578 |
+
src_len += 1
|
| 579 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 580 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 581 |
+
if attn_mask is not None:
|
| 582 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
| 583 |
+
if key_padding_mask is not None:
|
| 584 |
+
key_padding_mask = torch.cat(
|
| 585 |
+
[
|
| 586 |
+
key_padding_mask,
|
| 587 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
|
| 588 |
+
],
|
| 589 |
+
dim=1,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 593 |
+
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
| 594 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 595 |
+
|
| 596 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 597 |
+
|
| 598 |
+
if attn_mask is not None:
|
| 599 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 600 |
+
attn_weights += attn_mask
|
| 601 |
+
|
| 602 |
+
if key_padding_mask is not None:
|
| 603 |
+
# don't attend to padding symbols
|
| 604 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 605 |
+
if not is_tpu:
|
| 606 |
+
attn_weights = attn_weights.masked_fill(
|
| 607 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 608 |
+
float("-inf"),
|
| 609 |
+
)
|
| 610 |
+
else:
|
| 611 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 612 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 613 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 614 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 615 |
+
|
| 616 |
+
if before_softmax:
|
| 617 |
+
return attn_weights, v, position_bias
|
| 618 |
+
|
| 619 |
+
if position_bias is not None:
|
| 620 |
+
attn_mask_rel_pos = position_bias
|
| 621 |
+
if self.gru_rel_pos == 1:
|
| 622 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
| 623 |
+
_B, _H, _L, __ = query_layer.size()
|
| 624 |
+
gate_a, gate_b = torch.sigmoid(
|
| 625 |
+
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
| 626 |
+
).chunk(2, dim=-1)
|
| 627 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 628 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
| 629 |
+
|
| 630 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
| 631 |
+
|
| 632 |
+
attn_weights = attn_weights + attn_mask_rel_pos
|
| 633 |
+
|
| 634 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
| 635 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 636 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 637 |
+
|
| 638 |
+
assert v is not None
|
| 639 |
+
attn = torch.bmm(attn_probs, v)
|
| 640 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 641 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 642 |
+
attn = self.out_proj(attn)
|
| 643 |
+
attn_weights: Optional[Tensor] = None
|
| 644 |
+
if need_weights:
|
| 645 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
| 646 |
+
if not need_head_weights:
|
| 647 |
+
# average attention weights over heads
|
| 648 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 649 |
+
|
| 650 |
+
return attn, attn_weights, position_bias
|
| 651 |
+
|
| 652 |
+
@staticmethod
|
| 653 |
+
def _append_prev_key_padding_mask(
|
| 654 |
+
key_padding_mask: Optional[Tensor],
|
| 655 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 656 |
+
batch_size: int,
|
| 657 |
+
src_len: int,
|
| 658 |
+
static_kv: bool,
|
| 659 |
+
) -> Optional[Tensor]:
|
| 660 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 661 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 662 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 663 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 664 |
+
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
| 665 |
+
# During incremental decoding, as the padding token enters and
|
| 666 |
+
# leaves the frame, there will be a time when prev or current
|
| 667 |
+
# is None
|
| 668 |
+
elif prev_key_padding_mask is not None:
|
| 669 |
+
if src_len > prev_key_padding_mask.size(1):
|
| 670 |
+
filler = torch.zeros(
|
| 671 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 672 |
+
device=prev_key_padding_mask.device,
|
| 673 |
+
)
|
| 674 |
+
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
| 675 |
+
else:
|
| 676 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
| 677 |
+
elif key_padding_mask is not None:
|
| 678 |
+
if src_len > key_padding_mask.size(1):
|
| 679 |
+
filler = torch.zeros(
|
| 680 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 681 |
+
device=key_padding_mask.device,
|
| 682 |
+
)
|
| 683 |
+
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
| 684 |
+
else:
|
| 685 |
+
new_key_padding_mask = key_padding_mask.float()
|
| 686 |
+
else:
|
| 687 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 688 |
+
return new_key_padding_mask
|
| 689 |
+
|
| 690 |
+
def _get_input_buffer(
|
| 691 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 692 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 693 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 694 |
+
if result is not None:
|
| 695 |
+
return result
|
| 696 |
+
else:
|
| 697 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 698 |
+
return empty_result
|
| 699 |
+
|
| 700 |
+
def _set_input_buffer(
|
| 701 |
+
self,
|
| 702 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 703 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 704 |
+
):
|
| 705 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 706 |
+
|
| 707 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 708 |
+
return attn_weights
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def init_bert_params(module):
|
| 712 |
+
"""
|
| 713 |
+
Initialize the weights specific to the BERT Model.
|
| 714 |
+
This overrides the default initializations depending on the specified arguments.
|
| 715 |
+
1. If normal_init_linear_weights is set then weights of linear
|
| 716 |
+
layer will be initialized using the normal distribution and
|
| 717 |
+
bias will be set to the specified value.
|
| 718 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
| 719 |
+
layer will be initialized using the normal distribution.
|
| 720 |
+
3. If normal_init_proj_weights is set then weights of
|
| 721 |
+
in_project_weight for MultiHeadAttention initialized using
|
| 722 |
+
the normal distribution (to be validated).
|
| 723 |
+
"""
|
| 724 |
+
|
| 725 |
+
def normal_(data):
|
| 726 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
| 727 |
+
# so that the RNG is consistent with and without FSDP
|
| 728 |
+
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
| 729 |
+
|
| 730 |
+
if isinstance(module, nn.Linear):
|
| 731 |
+
normal_(module.weight.data)
|
| 732 |
+
if module.bias is not None:
|
| 733 |
+
module.bias.data.zero_()
|
| 734 |
+
if isinstance(module, nn.Embedding):
|
| 735 |
+
normal_(module.weight.data)
|
| 736 |
+
if module.padding_idx is not None:
|
| 737 |
+
module.weight.data[module.padding_idx].zero_()
|
| 738 |
+
if isinstance(module, MultiheadAttention):
|
| 739 |
+
normal_(module.q_proj.weight.data)
|
| 740 |
+
normal_(module.k_proj.weight.data)
|
| 741 |
+
normal_(module.v_proj.weight.data)
|
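For readers tracing `_relative_positions_bucket` above: offsets smaller than `max_exact` each get their own bucket, larger offsets share buckets on a log scale up to `max_distance`, and in the bidirectional case positive offsets are shifted into the upper half of the bucket range. The standalone Python sketch below mirrors that logic outside the class so it can be run in isolation; the default `num_buckets`/`max_distance` values are illustrative and not taken from the model config.

import math

import torch


def relative_positions_bucket(relative_positions: torch.Tensor,
                              num_buckets: int = 32,      # illustrative default
                              max_distance: int = 128,    # illustrative default
                              bidirectional: bool = True) -> torch.Tensor:
    """Mirror of MultiheadAttention._relative_positions_bucket above, as a free function."""
    relative_buckets = torch.zeros_like(relative_positions)
    if bidirectional:
        num_buckets = num_buckets // 2
        # positive (future) offsets use the upper half of the bucket range
        relative_buckets = relative_buckets + (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)
    else:
        relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

    max_exact = num_buckets // 2
    is_small = relative_positions < max_exact

    # larger offsets share buckets logarithmically, capped at the last bucket
    relative_position_if_large = max_exact + (
        torch.log(relative_positions.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    return relative_buckets + torch.where(is_small, relative_positions, relative_position_if_large)


offsets = torch.arange(-64, 65, 16)          # signed key - query offsets
print(relative_positions_bucket(offsets))    # small offsets get unique buckets, large ones are shared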
NatureLM/models/beats/modules.py
ADDED
|
@@ -0,0 +1,201 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GradMultiply(torch.autograd.Function):
|
| 19 |
+
@staticmethod
|
| 20 |
+
def forward(ctx, x, scale):
|
| 21 |
+
ctx.scale = scale
|
| 22 |
+
res = x.new(x)
|
| 23 |
+
return res
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def backward(ctx, grad):
|
| 27 |
+
return grad * ctx.scale, None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SamePad(nn.Module):
|
| 31 |
+
def __init__(self, kernel_size, causal=False):
|
| 32 |
+
super().__init__()
|
| 33 |
+
if causal:
|
| 34 |
+
self.remove = kernel_size - 1
|
| 35 |
+
else:
|
| 36 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
if self.remove > 0:
|
| 40 |
+
x = x[:, :, : -self.remove]
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Swish(nn.Module):
|
| 45 |
+
def __init__(self):
|
| 46 |
+
super(Swish, self).__init__()
|
| 47 |
+
self.act = torch.nn.Sigmoid()
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
return x * self.act(x)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class GLU_Linear(nn.Module):
|
| 54 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
| 55 |
+
super(GLU_Linear, self).__init__()
|
| 56 |
+
|
| 57 |
+
self.glu_type = glu_type
|
| 58 |
+
self.output_dim = output_dim
|
| 59 |
+
|
| 60 |
+
if glu_type == "sigmoid":
|
| 61 |
+
self.glu_act = torch.nn.Sigmoid()
|
| 62 |
+
elif glu_type == "swish":
|
| 63 |
+
self.glu_act = Swish()
|
| 64 |
+
elif glu_type == "relu":
|
| 65 |
+
self.glu_act = torch.nn.ReLU()
|
| 66 |
+
elif glu_type == "gelu":
|
| 67 |
+
self.glu_act = torch.nn.GELU()
|
| 68 |
+
|
| 69 |
+
if bias_in_glu:
|
| 70 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
| 71 |
+
else:
|
| 72 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
# to be consistent with GLU_Linear, we assume the input always has the channel (dim) in the last dimension of the tensor, so the dimensions need to be switched first for the 1D-Conv case
|
| 76 |
+
x = self.linear(x)
|
| 77 |
+
|
| 78 |
+
if self.glu_type == "bilinear":
|
| 79 |
+
x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
|
| 80 |
+
else:
|
| 81 |
+
x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
|
| 82 |
+
|
| 83 |
+
return x
|
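A minimal sketch of the gating performed by GLU_Linear above, assuming the class defined in this file is in scope (tensor sizes are arbitrary): the single linear layer projects to 2 * output_dim, the first half is used as values and the activated second half as a gate.

import torch

# assuming GLU_Linear from this file is in scope; sizes are arbitrary
glu = GLU_Linear(input_dim=64, output_dim=32, glu_type="swish")
x = torch.randn(4, 10, 64)  # (batch, time, channels) with channels last, as noted above
y = glu(x)                  # value half * swish(gate half)
print(y.shape)              # torch.Size([4, 10, 32])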
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def gelu_accurate(x):
|
| 87 |
+
if not hasattr(gelu_accurate, "_a"):
|
| 88 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
| 89 |
+
return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_activation_fn(activation: str):
|
| 97 |
+
"""Returns the activation function corresponding to `activation`"""
|
| 98 |
+
|
| 99 |
+
if activation == "relu":
|
| 100 |
+
return F.relu
|
| 101 |
+
elif activation == "gelu":
|
| 102 |
+
return gelu
|
| 103 |
+
elif activation == "gelu_fast":
|
| 104 |
+
warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
|
| 105 |
+
return gelu_accurate
|
| 106 |
+
elif activation == "gelu_accurate":
|
| 107 |
+
return gelu_accurate
|
| 108 |
+
elif activation == "tanh":
|
| 109 |
+
return torch.tanh
|
| 110 |
+
elif activation == "linear":
|
| 111 |
+
return lambda x: x
|
| 112 |
+
elif activation == "glu":
|
| 113 |
+
return lambda x: x
|
| 114 |
+
else:
|
| 115 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def quant_noise(module, p, block_size):
|
| 119 |
+
"""
|
| 120 |
+
Wraps modules and applies quantization noise to the weights for
|
| 121 |
+
subsequent quantization with Iterative Product Quantization as
|
| 122 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
- module: nn.Module
|
| 126 |
+
- p: amount of Quantization Noise
|
| 127 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
| 128 |
+
|
| 129 |
+
Remarks:
|
| 130 |
+
- Module weights must have the right sizes wrt the block size
|
| 131 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
| 132 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
| 133 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
| 134 |
+
- We implement the simplest form of noise here as stated in the paper
|
| 135 |
+
which consists in randomly dropping blocks
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
# if no quantization noise, don't register hook
|
| 139 |
+
if p <= 0:
|
| 140 |
+
return module
|
| 141 |
+
|
| 142 |
+
# supported modules
|
| 143 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
| 144 |
+
|
| 145 |
+
# test whether module.weight has the right sizes wrt block_size
|
| 146 |
+
is_conv = module.weight.ndim == 4
|
| 147 |
+
|
| 148 |
+
# 2D matrix
|
| 149 |
+
if not is_conv:
|
| 150 |
+
assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
|
| 151 |
+
|
| 152 |
+
# 4D matrix
|
| 153 |
+
else:
|
| 154 |
+
# 1x1 convolutions
|
| 155 |
+
if module.kernel_size == (1, 1):
|
| 156 |
+
assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
|
| 157 |
+
# regular convolutions
|
| 158 |
+
else:
|
| 159 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
| 160 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
| 161 |
+
|
| 162 |
+
def _forward_pre_hook(mod, input):
|
| 163 |
+
# no noise for evaluation
|
| 164 |
+
if mod.training:
|
| 165 |
+
if not is_conv:
|
| 166 |
+
# gather weight and sizes
|
| 167 |
+
weight = mod.weight
|
| 168 |
+
in_features = weight.size(1)
|
| 169 |
+
out_features = weight.size(0)
|
| 170 |
+
|
| 171 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 172 |
+
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
| 173 |
+
mask.bernoulli_(p)
|
| 174 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
| 175 |
+
|
| 176 |
+
else:
|
| 177 |
+
# gather weight and sizes
|
| 178 |
+
weight = mod.weight
|
| 179 |
+
in_channels = mod.in_channels
|
| 180 |
+
out_channels = mod.out_channels
|
| 181 |
+
|
| 182 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 183 |
+
if mod.kernel_size == (1, 1):
|
| 184 |
+
mask = torch.zeros(
|
| 185 |
+
int(in_channels // block_size * out_channels),
|
| 186 |
+
device=weight.device,
|
| 187 |
+
)
|
| 188 |
+
mask.bernoulli_(p)
|
| 189 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
| 190 |
+
else:
|
| 191 |
+
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
|
| 192 |
+
mask.bernoulli_(p)
|
| 193 |
+
mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
| 194 |
+
|
| 195 |
+
# scale weights and apply mask
|
| 196 |
+
mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
|
| 197 |
+
s = 1 / (1 - p)
|
| 198 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
| 199 |
+
|
| 200 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
| 201 |
+
return module
|
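The sketch below shows how `quant_noise` is intended to be used, assuming the function above is in scope: wrap a supported module once, and the registered pre-forward hook block-masks and rescales the weights only while the module is in training mode. Sizes are arbitrary but must be multiples of `block_size`.

import torch
from torch import nn

# assuming quant_noise from this file is in scope; in_features must be a multiple of block_size
layer = quant_noise(nn.Linear(64, 128), p=0.1, block_size=8)

layer.train()
_ = layer(torch.randn(2, 64))  # weights are block-masked and rescaled by 1 / (1 - p)

layer.eval()
_ = layer(torch.randn(2, 64))  # the pre-forward hook is a no-op outside training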
NatureLM/models/beats/quantizer.py
ADDED
|
@@ -0,0 +1,222 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on VQGAN code bases
|
| 7 |
+
# https://github.com/CompVis/taming-transformers
|
| 8 |
+
# --------------------------------------------------------'
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as distributed
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
except ImportError:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def l2norm(t):
|
| 22 |
+
return F.normalize(t, p=2, dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ema_inplace(moving_avg, new, decay):
|
| 26 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def sample_vectors(samples, num):
|
| 30 |
+
num_samples, device = samples.shape[0], samples.device
|
| 31 |
+
|
| 32 |
+
if num_samples >= num:
|
| 33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 34 |
+
else:
|
| 35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 36 |
+
|
| 37 |
+
return samples[indices]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
| 41 |
+
dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device
|
| 42 |
+
|
| 43 |
+
means = sample_vectors(samples, num_clusters)
|
| 44 |
+
|
| 45 |
+
for _ in range(num_iters):
|
| 46 |
+
if use_cosine_sim:
|
| 47 |
+
dists = samples @ means.t()
|
| 48 |
+
else:
|
| 49 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
| 50 |
+
dists = -(diffs**2).sum(dim=-1)
|
| 51 |
+
|
| 52 |
+
buckets = dists.max(dim=-1).indices
|
| 53 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 54 |
+
zero_mask = bins == 0
|
| 55 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 56 |
+
|
| 57 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 58 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
| 59 |
+
new_means = new_means / bins_min_clamped[..., None]
|
| 60 |
+
|
| 61 |
+
if use_cosine_sim:
|
| 62 |
+
new_means = l2norm(new_means)
|
| 63 |
+
|
| 64 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
| 65 |
+
|
| 66 |
+
return means, bins
|
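A small standalone check of the `kmeans` helper above, which is what seeds the codebook when `kmeans_init=True`: it picks random samples as initial means, then alternates assignment and re-estimation, keeping empty clusters at their previous mean. The data and sizes below are made up, and running it requires `einops` to be installed (the helper uses it for the scatter step).

import torch
import torch.nn.functional as F

# assuming kmeans() from this file is in scope; data is random, sizes are illustrative
feats = F.normalize(torch.randn(1000, 256), p=2, dim=-1)

means, bins = kmeans(feats, num_clusters=16, num_iters=10, use_cosine_sim=True)
print(means.shape, bins.shape)  # torch.Size([16, 256]) torch.Size([16])
print(int(bins.sum()))          # 1000: every sample assigned to a cluster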
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class EmbeddingEMA(nn.Module):
|
| 70 |
+
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=""):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.num_tokens = num_tokens
|
| 73 |
+
self.codebook_dim = codebook_dim
|
| 74 |
+
self.decay = decay
|
| 75 |
+
self.eps = eps
|
| 76 |
+
if codebook_init_path == "":
|
| 77 |
+
if not kmeans_init:
|
| 78 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
| 79 |
+
weight = l2norm(weight)
|
| 80 |
+
else:
|
| 81 |
+
weight = torch.zeros(num_tokens, codebook_dim)
|
| 82 |
+
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
|
| 83 |
+
else:
|
| 84 |
+
print(f"load init codebook weight from {codebook_init_path}")
|
| 85 |
+
codebook_ckpt_weight = torch.load(codebook_init_path, map_location="cpu")
|
| 86 |
+
weight = codebook_ckpt_weight.clone()
|
| 87 |
+
self.register_buffer("initted", torch.Tensor([True]))
|
| 88 |
+
|
| 89 |
+
self.weight = nn.Parameter(weight, requires_grad=False)
|
| 90 |
+
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
| 91 |
+
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
| 92 |
+
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
| 93 |
+
self.update = True
|
| 94 |
+
|
| 95 |
+
@torch.jit.ignore
|
| 96 |
+
def init_embed_(self, data):
|
| 97 |
+
if self.initted:
|
| 98 |
+
return
|
| 99 |
+
print("Performing Kemans init for codebook")
|
| 100 |
+
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
|
| 101 |
+
self.weight.data.copy_(embed)
|
| 102 |
+
self.cluster_size.data.copy_(cluster_size)
|
| 103 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
| 104 |
+
|
| 105 |
+
def forward(self, embed_id):
|
| 106 |
+
return F.embedding(embed_id, self.weight)
|
| 107 |
+
|
| 108 |
+
def cluster_size_ema_update(self, new_cluster_size):
|
| 109 |
+
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
| 110 |
+
|
| 111 |
+
def embed_avg_ema_update(self, new_embed_avg):
|
| 112 |
+
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
| 113 |
+
|
| 114 |
+
def weight_update(self, num_tokens):
|
| 115 |
+
n = self.cluster_size.sum()
|
| 116 |
+
smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
| 117 |
+
# normalize embedding average with smoothed cluster size
|
| 118 |
+
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
| 119 |
+
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
|
| 120 |
+
self.weight.data.copy_(embed_normalized)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def norm_ema_inplace(moving_avg, new, decay):
|
| 124 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 125 |
+
moving_avg.data.copy_(l2norm(moving_avg.data))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class NormEMAVectorQuantizer(nn.Module):
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
n_embed,
|
| 132 |
+
embedding_dim,
|
| 133 |
+
beta,
|
| 134 |
+
decay=0.99,
|
| 135 |
+
eps=1e-5,
|
| 136 |
+
statistic_code_usage=True,
|
| 137 |
+
kmeans_init=False,
|
| 138 |
+
codebook_init_path="",
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.codebook_dim = embedding_dim
|
| 142 |
+
self.num_tokens = n_embed
|
| 143 |
+
self.beta = beta
|
| 144 |
+
self.decay = decay
|
| 145 |
+
|
| 146 |
+
# learnable = True if orthogonal_reg_weight > 0 else False
|
| 147 |
+
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
|
| 148 |
+
|
| 149 |
+
self.statistic_code_usage = statistic_code_usage
|
| 150 |
+
if statistic_code_usage:
|
| 151 |
+
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
| 152 |
+
if distributed.is_available() and distributed.is_initialized():
|
| 153 |
+
print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
|
| 154 |
+
self.all_reduce_fn = distributed.all_reduce
|
| 155 |
+
else:
|
| 156 |
+
self.all_reduce_fn = nn.Identity()
|
| 157 |
+
|
| 158 |
+
def reset_cluster_size(self, device):
|
| 159 |
+
if self.statistic_code_usage:
|
| 160 |
+
self.register_buffer("cluster_size", torch.zeros(self.num_tokens))
|
| 161 |
+
self.cluster_size = self.cluster_size.to(device)
|
| 162 |
+
|
| 163 |
+
def forward(self, z):
|
| 164 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 165 |
+
# z, 'b c h w -> b h w c'
|
| 166 |
+
# z = rearrange(z, 'b c h w -> b h w c')
|
| 167 |
+
# z = z.transpose(1, 2)
|
| 168 |
+
z = l2norm(z)
|
| 169 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
| 170 |
+
|
| 171 |
+
self.embedding.init_embed_(z_flattened)
|
| 172 |
+
|
| 173 |
+
d = (
|
| 174 |
+
z_flattened.pow(2).sum(dim=1, keepdim=True)
|
| 175 |
+
+ self.embedding.weight.pow(2).sum(dim=1)
|
| 176 |
+
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
|
| 177 |
+
) # 'n d -> d n'
|
| 178 |
+
|
| 179 |
+
encoding_indices = torch.argmin(d, dim=1)
|
| 180 |
+
|
| 181 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
| 182 |
+
|
| 183 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
| 184 |
+
|
| 185 |
+
if not self.training:
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
cluster_size = encodings.sum(0)
|
| 188 |
+
self.all_reduce_fn(cluster_size)
|
| 189 |
+
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
| 190 |
+
|
| 191 |
+
if self.training and self.embedding.update:
|
| 192 |
+
# EMA cluster size
|
| 193 |
+
|
| 194 |
+
bins = encodings.sum(0)
|
| 195 |
+
self.all_reduce_fn(bins)
|
| 196 |
+
|
| 197 |
+
# self.embedding.cluster_size_ema_update(bins)
|
| 198 |
+
ema_inplace(self.cluster_size, bins, self.decay)
|
| 199 |
+
|
| 200 |
+
zero_mask = bins == 0
|
| 201 |
+
bins = bins.masked_fill(zero_mask, 1.0)
|
| 202 |
+
|
| 203 |
+
embed_sum = z_flattened.t() @ encodings
|
| 204 |
+
self.all_reduce_fn(embed_sum)
|
| 205 |
+
|
| 206 |
+
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
| 207 |
+
embed_normalized = l2norm(embed_normalized)
|
| 208 |
+
|
| 209 |
+
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, embed_normalized)
|
| 210 |
+
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
|
| 211 |
+
|
| 212 |
+
# compute loss for embedding
|
| 213 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
| 214 |
+
|
| 215 |
+
# preserve gradients
|
| 216 |
+
z_q = z + (z_q - z).detach()
|
| 217 |
+
|
| 218 |
+
# reshape back to match original input shape
|
| 219 |
+
# z_q, 'b h w c -> b c h w'
|
| 220 |
+
# z_q = rearrange(z_q, 'b h w c -> b c h w')
|
| 221 |
+
# z_q = z_q.transpose(1, 2)
|
| 222 |
+
return z_q, loss, encoding_indices
|
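A quick end-to-end sketch of `NormEMAVectorQuantizer` above, with illustrative shapes and hyperparameters: inputs are l2-normalised, snapped to the nearest codeword, and returned through a straight-through estimator so gradients still reach the encoder, while the codebook itself is only updated by EMA during training.

import torch

# assuming NormEMAVectorQuantizer from this file is in scope; sizes are illustrative
vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0, kmeans_init=False)
vq.eval()  # keep the EMA codebook frozen for this demo

z = torch.randn(8, 50, 256, requires_grad=True)  # e.g. (batch, time, dim) encoder output
z_q, loss, codes = vq(z)

print(z_q.shape, codes.shape)  # torch.Size([8, 50, 256]) torch.Size([400])
loss.backward()                # commitment loss reaches the encoder through z
print(z.grad is not None)      # True: the straight-through estimator keeps the path differentiable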
NatureLM/models/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
| 1 |
+
# Copyright (2024) Earth Species Project
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import StoppingCriteria
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
| 20 |
+
def __init__(self, stops=[], encounters=1):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.stops = stops
|
| 23 |
+
|
| 24 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 25 |
+
for stop in self.stops:
|
| 26 |
+
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
| 27 |
+
return True
|
| 28 |
+
|
| 29 |
+
return False
|
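A sketch of how `StoppingCriteriaSub` is typically wired into Hugging Face generation: stop token ids are compared against the tail of the generated sequence after each step. The gpt2 model, tokenizer and stop sequence below are placeholders for illustration only, not values taken from this repository's config.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

# placeholder model and tokenizer, purely for illustration
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_ids = [torch.tensor(tokenizer.encode("\n"))]   # stop once a newline has been generated
stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])

inputs = tokenizer("Common name:", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32, stopping_criteria=stopping)
print(tokenizer.decode(out[0]))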
NatureLM/optims.py
ADDED
|
@@ -0,0 +1,154 @@
|
| 1 |
+
# This script is from https://github.com/salesforce/LAVIS/blob/main/lavis/common/optims.py
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from NatureLM.config import OptimizerConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LinearWarmupStepLRScheduler:
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
optimizer,
|
| 15 |
+
max_epoch,
|
| 16 |
+
min_lr,
|
| 17 |
+
init_lr,
|
| 18 |
+
decay_rate=1,
|
| 19 |
+
warmup_start_lr=-1,
|
| 20 |
+
warmup_steps=0,
|
| 21 |
+
**kwargs,
|
| 22 |
+
):
|
| 23 |
+
self.optimizer = optimizer
|
| 24 |
+
|
| 25 |
+
self.max_epoch = max_epoch
|
| 26 |
+
self.min_lr = min_lr
|
| 27 |
+
|
| 28 |
+
self.decay_rate = decay_rate
|
| 29 |
+
|
| 30 |
+
self.init_lr = init_lr
|
| 31 |
+
self.warmup_steps = warmup_steps
|
| 32 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
| 33 |
+
|
| 34 |
+
def step(self, cur_epoch, cur_step):
|
| 35 |
+
if cur_epoch == 0:
|
| 36 |
+
warmup_lr_schedule(
|
| 37 |
+
step=cur_step,
|
| 38 |
+
optimizer=self.optimizer,
|
| 39 |
+
max_step=self.warmup_steps,
|
| 40 |
+
init_lr=self.warmup_start_lr,
|
| 41 |
+
max_lr=self.init_lr,
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
step_lr_schedule(
|
| 45 |
+
epoch=cur_epoch,
|
| 46 |
+
optimizer=self.optimizer,
|
| 47 |
+
init_lr=self.init_lr,
|
| 48 |
+
min_lr=self.min_lr,
|
| 49 |
+
decay_rate=self.decay_rate,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class LinearWarmupCosineLRScheduler:
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
optimizer,
|
| 57 |
+
max_epoch,
|
| 58 |
+
iters_per_epoch,
|
| 59 |
+
min_lr,
|
| 60 |
+
init_lr,
|
| 61 |
+
warmup_steps=0,
|
| 62 |
+
warmup_start_lr=-1,
|
| 63 |
+
**kwargs,
|
| 64 |
+
):
|
| 65 |
+
self.optimizer = optimizer
|
| 66 |
+
|
| 67 |
+
self.max_epoch = max_epoch
|
| 68 |
+
self.iters_per_epoch = iters_per_epoch
|
| 69 |
+
self.min_lr = min_lr
|
| 70 |
+
|
| 71 |
+
self.init_lr = init_lr
|
| 72 |
+
self.warmup_steps = warmup_steps
|
| 73 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
| 74 |
+
|
| 75 |
+
def step(self, cur_epoch, cur_step):
|
| 76 |
+
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
| 77 |
+
if total_cur_step < self.warmup_steps:
|
| 78 |
+
warmup_lr_schedule(
|
| 79 |
+
step=cur_step,
|
| 80 |
+
optimizer=self.optimizer,
|
| 81 |
+
max_step=self.warmup_steps,
|
| 82 |
+
init_lr=self.warmup_start_lr,
|
| 83 |
+
max_lr=self.init_lr,
|
| 84 |
+
)
|
| 85 |
+
else:
|
| 86 |
+
cosine_lr_schedule(
|
| 87 |
+
epoch=total_cur_step,
|
| 88 |
+
optimizer=self.optimizer,
|
| 89 |
+
max_epoch=self.max_epoch * self.iters_per_epoch,
|
| 90 |
+
init_lr=self.init_lr,
|
| 91 |
+
min_lr=self.min_lr,
|
| 92 |
+
)
|
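To make the schedule concrete: for the first `warmup_steps` global steps the learning rate ramps linearly from `warmup_start_lr` to `init_lr`, then follows a cosine decay from `init_lr` down to `min_lr` over `max_epoch * iters_per_epoch` steps. A minimal sketch with a dummy model and made-up hyperparameters, assuming the class and the helper functions in this file are in scope:

import torch
from torch import nn

# dummy model and illustrative hyperparameters
model = nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
sched = LinearWarmupCosineLRScheduler(
    opt, max_epoch=2, iters_per_epoch=100, min_lr=1e-5, init_lr=1e-4,
    warmup_steps=50, warmup_start_lr=1e-6,
)

for epoch in range(2):
    for step in range(100):
        sched.step(cur_epoch=epoch, cur_step=step)
    # linear ramp over the first 50 steps, cosine decay toward min_lr afterwards
    print(epoch, opt.param_groups[0]["lr"])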
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
| 96 |
+
"""Decay the learning rate"""
|
| 97 |
+
lr = (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
|
| 98 |
+
for param_group in optimizer.param_groups:
|
| 99 |
+
param_group["lr"] = lr
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
| 103 |
+
"""Warmup the learning rate"""
|
| 104 |
+
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
| 105 |
+
for param_group in optimizer.param_groups:
|
| 106 |
+
param_group["lr"] = lr
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
| 110 |
+
"""Decay the learning rate"""
|
| 111 |
+
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
| 112 |
+
for param_group in optimizer.param_groups:
|
| 113 |
+
param_group["lr"] = lr
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_optimizer(model, config: OptimizerConfig):
|
| 117 |
+
num_parameters = 0
|
| 118 |
+
p_wd, p_non_wd = [], []
|
| 119 |
+
for n, p in model.named_parameters():
|
| 120 |
+
if not p.requires_grad:
|
| 121 |
+
continue # frozen weights
|
| 122 |
+
print(n)
|
| 123 |
+
if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
|
| 124 |
+
p_non_wd.append(p)
|
| 125 |
+
else:
|
| 126 |
+
p_wd.append(p)
|
| 127 |
+
num_parameters += p.data.nelement()
|
| 128 |
+
logging.info("number of trainable parameters: %d" % num_parameters)
|
| 129 |
+
optim_params = [
|
| 130 |
+
{
|
| 131 |
+
"params": p_wd,
|
| 132 |
+
"weight_decay": float(config.weight_decay),
|
| 133 |
+
},
|
| 134 |
+
{"params": p_non_wd, "weight_decay": 0},
|
| 135 |
+
]
|
| 136 |
+
beta2 = config.beta2
|
| 137 |
+
if config.device == "cpu":
|
| 138 |
+
optimizer = torch.optim.AdamW(
|
| 139 |
+
optim_params,
|
| 140 |
+
lr=float(config.init_lr),
|
| 141 |
+
weight_decay=float(config.weight_decay),
|
| 142 |
+
betas=(0.9, beta2),
|
| 143 |
+
)
|
| 144 |
+
else:
|
| 145 |
+
import bitsandbytes as bnb
|
| 146 |
+
|
| 147 |
+
optimizer = bnb.optim.PagedAdamW8bit(
|
| 148 |
+
optim_params,
|
| 149 |
+
lr=float(config.init_lr),
|
| 150 |
+
weight_decay=float(config.weight_decay),
|
| 151 |
+
betas=(0.9, beta2),
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return optimizer
|
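`get_optimizer` above applies weight decay only to trainable parameters with two or more dimensions, excluding anything whose name contains "bias", "ln" or "bn"; on GPU it then hands both groups to bitsandbytes' paged 8-bit AdamW. The sketch below reproduces just that grouping rule on a toy model so the split is easy to inspect; it deliberately avoids constructing an `OptimizerConfig`.

from torch import nn

# toy model; the real call site passes the NatureLM model and an OptimizerConfig
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen weights are skipped, as in get_optimizer
    if param.ndim < 2 or "bias" in name or "ln" in name or "bn" in name:
        no_decay.append(name)  # biases and norm parameters: no weight decay
    else:
        decay.append(name)     # weight matrices: weight decay applies

print(decay)     # ['0.weight']
print(no_decay)  # ['0.bias', '1.weight', '1.bias']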
NatureLM/processors.py
ADDED
|
@@ -0,0 +1,278 @@
|
| 1 |
+
"""Module contains the audio and text processor for NatureLM-audio inference and evaluation"""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import resampy
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class NatureLMAudioProcessor:
|
| 15 |
+
"""Preprocess samples to make them ready for NatureLM-audio inference.
|
| 16 |
+
|
| 17 |
+
Arguments
|
| 18 |
+
---------
|
| 19 |
+
sample_rate : int
|
| 20 |
+
The sample rate of the NatureLM model
|
| 21 |
+
max_length_seconds : int
|
| 22 |
+
The maximum length of audio in seconds
|
| 23 |
+
audio_token_placeholder : str
|
| 24 |
+
The placeholder for the audio token in the instruction
|
| 25 |
+
prompt_template : str
|
| 26 |
+
The template for the prompt. The instruction or query from the user is inserted in the placeholder at {prompt}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
Examples
|
| 30 |
+
--------
|
| 31 |
+
>>> processor = NatureLMAudioProcessor()
|
| 32 |
+
>>> audios = [np.random.rand(32000), np.random.rand(32000)]
|
| 33 |
+
>>> instructions = ["What is the weather today?", "What is the time now?"]
|
| 34 |
+
>>> input_sample_rates = [32000, 32000]
|
| 35 |
+
>>> audios, instructions = processor(audios, instructions, input_sample_rates)
|
| 36 |
+
>>> audios.shape == (2, 160000)
|
| 37 |
+
True
|
| 38 |
+
>>> "<Audio><AudioHere></Audio> " in instructions[0]
|
| 39 |
+
True
|
| 40 |
+
>>> "<|start_header_id|>user<|end_header_id|>" in instructions[0]
|
| 41 |
+
True
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
sample_rate: int = 16000
|
| 45 |
+
max_length_seconds: int = 10
|
| 46 |
+
audio_token_placeholder: str = "<Audio><AudioHere></Audio> "
|
| 47 |
+
prompt_template: str = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 48 |
+
|
| 49 |
+
def prepare_audio(self, audio: list[float] | np.ndarray | os.PathLike, input_sr: int = None) -> torch.Tensor:
|
| 50 |
+
"""Prepare an audio array or file path for inference"""
|
| 51 |
+
if isinstance(audio, str | os.PathLike):
|
| 52 |
+
audio, sr = sf.read(audio)
|
| 53 |
+
input_sr = sr
|
| 54 |
+
elif isinstance(audio, list):
|
| 55 |
+
audio = np.array(audio)
|
| 56 |
+
|
| 57 |
+
assert isinstance(audio, np.ndarray), "Audio not a numpy array"
|
| 58 |
+
|
| 59 |
+
# Convert stereo to mono
|
| 60 |
+
if len(audio.shape) == 2:
|
| 61 |
+
# find the smaller axis as channel dim to avg over (like (2, T) or (T, 2), 2 = channel dim
|
| 62 |
+
axis_to_average = int(np.argmin(audio.shape))
|
| 63 |
+
audio = audio.mean(axis=axis_to_average)
|
| 64 |
+
|
| 65 |
+
# Resample
|
| 66 |
+
if input_sr is not None and input_sr != self.sample_rate:
|
| 67 |
+
# audio = torchaudio.functional.resample(
|
| 68 |
+
# torch.from_numpy(audio), orig_freq=input_sr, new_freq=self.sample_rate
|
| 69 |
+
# )
|
| 70 |
+
audio = resampy.resample(audio, input_sr, self.sample_rate)
|
| 71 |
+
audio = torch.from_numpy(audio.squeeze())
|
| 72 |
+
else:
|
| 73 |
+
audio = torch.from_numpy(audio)
|
| 74 |
+
|
| 75 |
+
# Truncate audio to at most max_length_seconds
|
| 76 |
+
audio = audio[: self.sample_rate * self.max_length_seconds]
|
| 77 |
+
|
| 78 |
+
# Pad to max_length_seconds if short
|
| 79 |
+
if len(audio) < self.sample_rate * self.max_length_seconds:
|
| 80 |
+
pad_size = self.sample_rate * self.max_length_seconds - len(audio)
|
| 81 |
+
audio = torch.nn.functional.pad(audio, (0, pad_size))
|
| 82 |
+
|
| 83 |
+
# Clamp
|
| 84 |
+
audio = torch.clamp(audio, -1.0, 1.0)
|
| 85 |
+
|
| 86 |
+
return audio.squeeze()
|
| 87 |
+
|
| 88 |
+
def prepare_instruction(self, instruction: str) -> str:
|
| 89 |
+
"""Add the audio token placeholder to the instruction and format it
|
| 90 |
+
according to the llama tokenizer.
|
| 91 |
+
"""
|
| 92 |
+
if self.audio_token_placeholder not in instruction:
|
| 93 |
+
instruction = self.audio_token_placeholder + instruction
|
| 94 |
+
instruction = self.prompt_template.format(prompt=instruction.strip())
|
| 95 |
+
|
| 96 |
+
return instruction
|
| 97 |
+
|
| 98 |
+
def __call__(
|
| 99 |
+
self,
|
| 100 |
+
audios: list[list[float] | np.ndarray] | list[str | os.PathLike],
|
| 101 |
+
instructions: list[str],
|
| 102 |
+
input_sample_rates: list[int],
|
| 103 |
+
) -> tuple[torch.Tensor, list[str]]:
|
| 104 |
+
"""Prepare audios and instructions for inference
|
| 105 |
+
|
| 106 |
+
Arguments
|
| 107 |
+
---------
|
| 108 |
+
audios : list[list[float] | np.ndarray] | list[str | os.PathLike]
|
| 109 |
+
The audio samples or file paths
|
| 110 |
+
instructions : list[str]
|
| 111 |
+
The instructions or queries
|
| 112 |
+
input_sample_rates : list[int]
|
| 113 |
+
The sample rates of the input audio samples
|
| 114 |
+
|
| 115 |
+
Returns
|
| 116 |
+
-------
|
| 117 |
+
tuple[torch.Tensor, list[str]]
|
| 118 |
+
The prepared audios and instructions
|
| 119 |
+
"""
|
| 120 |
+
audios = torch.stack(
|
| 121 |
+
[self.prepare_audio(audio, input_sr) for audio, input_sr in zip(audios, input_sample_rates)]
|
| 122 |
+
)
|
| 123 |
+
instructions = [self.prepare_instruction(instruction) for instruction in instructions]
|
| 124 |
+
|
| 125 |
+
return audios, instructions
|
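Besides raw arrays, `prepare_audio` above also accepts a file path, in which case the sample rate is read from the file itself and the corresponding entry in `input_sample_rates` is ignored. A short sketch with a hypothetical wav path (any file readable by soundfile would do), assuming the class above is in scope:

# assuming NatureLMAudioProcessor from this file is in scope
processor = NatureLMAudioProcessor(sample_rate=16000, max_length_seconds=10)

audios, prompts = processor(
    ["path/to/recording.wav"],         # hypothetical file; soundfile reads the true sample rate
    ["Which species is vocalising?"],  # hypothetical user query
    [None],                            # ignored for file inputs
)
print(audios.shape)  # torch.Size([1, 160000]): mono, 16 kHz, truncated/padded to 10 s
print(prompts[0])    # query wrapped in the audio placeholder and the llama chat template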
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@dataclass
|
| 129 |
+
class NatureLMAudioEvalProcessor(NatureLMAudioProcessor):
|
| 130 |
+
"""Preprocess samples to make them ready for NatureLM-audio evaluation on BEANS-Zero dataset.
|
| 131 |
+
This requires a few additional parameters compared to the NatureLMAudioProcessor.
|
| 132 |
+
|
| 133 |
+
Arguments
|
| 134 |
+
---------
|
| 135 |
+
sample_rate : int
|
| 136 |
+
The sample rate of the NatureLM model
|
| 137 |
+
max_length_seconds : int
|
| 138 |
+
The maximum length of audio in seconds
|
| 139 |
+
audio_token_placeholder : str
|
| 140 |
+
The placeholder for the audio token in the instruction
|
| 141 |
+
prompt_template : str
|
| 142 |
+
The template for the prompt. The instruction or query from the user is inserted in the placeholder at {prompt}
|
| 143 |
+
|
| 144 |
+
dataset_name : list[str]
|
| 145 |
+
The name of the dataset being processed
|
| 146 |
+
true_labels : list[str]
|
| 147 |
+
The true labels or expected outputs for the samples.
|
| 148 |
+
task: str
|
| 149 |
+
The task for the dataset. Can be 'detection', 'captioning', or 'classification'
|
| 150 |
+
threshold_too_many_detection_labels : int
|
| 151 |
+
The threshold for the number of labels in the dataset to switch to a detection prompt. Default is 8.
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
Examples
|
| 155 |
+
--------
|
| 156 |
+
>>> processor = NatureLMAudioEvalProcessor(task="detection", true_labels=["dog", "cat", "bird", "None", "mouse", "elephant", "lion", "tiger", "bear"])
|
| 157 |
+
>>> audios = [np.random.rand(32000), np.random.rand(32000)]
|
| 158 |
+
>>> instructions = ["What is the weather today?", "What is the time now?"]
|
| 159 |
+
>>> input_sample_rates = [32000, 32000]
|
| 160 |
+
>>> audios, instructions = processor(audios, instructions, input_sample_rates)
|
| 161 |
+
>>> audios.shape == (2, 160000)
|
| 162 |
+
True
|
| 163 |
+
>>> "<Audio><AudioHere></Audio> " in instructions[0]
|
| 164 |
+
True
|
| 165 |
+
>>> "<|start_header_id|>user<|end_header_id|>" in instructions[0]
|
| 166 |
+
True
|
| 167 |
+
>>> "What are the common names" in instructions[0]
|
| 168 |
+
True
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
dataset_name: str = "beans-zero"
|
| 172 |
+
true_labels: list[str] = field(default_factory=list)
|
| 173 |
+
task: str = "detection"
|
| 174 |
+
|
| 175 |
+
threshold_too_many_detection_labels: int = 8
|
| 176 |
+
|
| 177 |
+
def __post_init__(self):
|
| 178 |
+
self.detection_prompt: str = (
|
| 179 |
+
"<Audio><AudioHere></Audio> What are the common names for the species in the audio, if any?"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# find the unique labels in the dataset
|
| 183 |
+
self.dataset_labels = set(self.true_labels)
|
| 184 |
+
if self.task == "detection":
|
| 185 |
+
self.dataset_labels.add("None")
|
| 186 |
+
if self.task == "captioning":
|
| 187 |
+
self.dataset_labels = set()
|
| 188 |
+
|
| 189 |
+
def prepare_instruction(self, instruction: str) -> str:
|
| 190 |
+
"""Add the audio token placeholder to the instruction and format it"""
|
| 191 |
+
if self.task == "detection" and len(self.dataset_labels) > self.threshold_too_many_detection_labels:
|
| 192 |
+
instruction = self.detection_prompt
|
| 193 |
+
|
| 194 |
+
if self.audio_token_placeholder not in instruction:
|
| 195 |
+
instruction = self.audio_token_placeholder + instruction
|
| 196 |
+
|
| 197 |
+
instruction = self.prompt_template.format(prompt=instruction.strip())
|
| 198 |
+
|
| 199 |
+
return instruction
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class NatureLMInferenceDataset(torch.utils.data.Dataset):
|
| 203 |
+
"""A pytorch dataset for batched inference with NatureLM-audio
|
| 204 |
+
|
| 205 |
+
TODO: currently, if the batch contains very different prompts, the model doesn't work well.
|
| 206 |
+
|
| 207 |
+
Arguments
|
| 208 |
+
---------
|
| 209 |
+
ds : datasets.Dataset
|
| 210 |
+
The huggingface dataset containing the samples
|
| 211 |
+
|
| 212 |
+
Examples
|
| 213 |
+
--------
|
| 214 |
+
TODO: Add examples
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
def __init__(self, ds, processor):
|
| 218 |
+
self.ds = ds
|
| 219 |
+
self.processor = processor
|
| 220 |
+
|
| 221 |
+
def __getitem__(self, idx):
|
| 222 |
+
sample = self.ds[idx]
|
| 223 |
+
input_sample_rate = json.loads(sample["metadata"])["sample_rate"]
|
| 224 |
+
audio_tensor = self.processor.prepare_audio(sample["audio"], input_sample_rate)
|
| 225 |
+
|
| 226 |
+
instruction = self.processor.prepare_instruction(sample["instruction"])
|
| 227 |
+
return {
|
| 228 |
+
"raw_wav": audio_tensor,
|
| 229 |
+
"text": "",
|
| 230 |
+
"task": sample["task"],
|
| 231 |
+
"audio_chunk_sizes": len(audio_tensor),
|
| 232 |
+
"index": idx,
|
| 233 |
+
"id": sample["id"],
|
| 234 |
+
"prompt": instruction,
|
| 235 |
+
"label": sample["output"],
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
def __len__(self):
|
| 239 |
+
return len(self.ds)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def collater(samples: list[dict]) -> dict:
|
| 243 |
+
"""Collate samples into a batch.
|
| 244 |
+
|
| 245 |
+
Samples is a list of dictionaries, each containing the following keys:
|
| 246 |
+
- raw_wav: a list of tensors containing the raw audio waveform
|
| 247 |
+
- text: a list of strings containing the text
|
| 248 |
+
- task: a list of strings containing the task
|
| 249 |
+
- id: a list of strings containing the id
|
| 250 |
+
- prompt: a list of strings containing the prompt
|
| 251 |
+
- index: a list of integers containing the index
|
| 252 |
+
- audio_chunk_sizes: a list of integers containing the size of each audio chunk
|
| 253 |
+
|
| 254 |
+
The individual audio waveforms will be stacked along the batch dimension for easier
|
| 255 |
+
processing in the audio model. To keep which audio belongs to which sample, we add
|
| 256 |
+
the audio_chunk_sizes key to the batch dictionary.
|
| 257 |
+
"""
|
| 258 |
+
raw_wav = torch.stack([s["raw_wav"] for s in samples])
|
| 259 |
+
padding_mask = torch.zeros_like(raw_wav).to(torch.bool)
|
| 260 |
+
|
| 261 |
+
text = [s["text"] for s in samples]
|
| 262 |
+
prompt = [s["prompt"] for s in samples]
|
| 263 |
+
task = [s["task"] for s in samples]
|
| 264 |
+
id = [s["id"] for s in samples]
|
| 265 |
+
index = [s["index"] for s in samples]
|
| 266 |
+
label = [s["label"] for s in samples]
|
| 267 |
+
|
| 268 |
+
return {
|
| 269 |
+
"raw_wav": raw_wav,
|
| 270 |
+
"padding_mask": paddding_mask,
|
| 271 |
+
"text": text,
|
| 272 |
+
"task": task,
|
| 273 |
+
"id": id,
|
| 274 |
+
"prompt": prompt,
|
| 275 |
+
"index": index,
|
| 276 |
+
"audio_chunk_sizes": 1,
|
| 277 |
+
"label": label,
|
| 278 |
+
}
|
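Putting this module together, a sketch of the intended wiring of dataset, processor and collater. It assumes a Hugging Face dataset whose rows expose the fields used in `__getitem__` above (`audio` as a raw waveform, `instruction`, `task`, `id`, `output`, and a `metadata` JSON string containing `sample_rate`); the dataset id below is only illustrative.

import datasets
from torch.utils.data import DataLoader

# illustrative dataset id; any dataset with the expected columns would work
ds = datasets.load_dataset("EarthSpeciesProject/BEANS-Zero", split="test")

processor = NatureLMAudioProcessor(sample_rate=16000, max_length_seconds=10)
infer_ds = NatureLMInferenceDataset(ds, processor)

loader = DataLoader(infer_ds, batch_size=4, collate_fn=collater)
batch = next(iter(loader))
print(batch["raw_wav"].shape)  # (4, 160000)
print(batch["prompt"][0])      # templated instruction containing the <Audio> placeholder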
NatureLM/runner.py
ADDED
|
@@ -0,0 +1,515 @@
|
| 1 |
+
# This script is based on https://github.com/salesforce/LAVIS/blob/main/lavis/runners/runner_base.py
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import wandb
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 16 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
+
|
| 18 |
+
from NatureLM.config import Config
|
| 19 |
+
from NatureLM.dist_utils import get_rank, get_world_size, is_dist_avail_and_initialized, is_main_process, main_process
|
| 20 |
+
from NatureLM.logger import MetricLogger, SmoothedValue
|
| 21 |
+
from NatureLM.optims import LinearWarmupCosineLRScheduler, get_optimizer
|
| 22 |
+
from NatureLM.task_metrics import get_task_metrics
|
| 23 |
+
from NatureLM.utils import get_dataloader, prepare_sample_dist
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Runner:
|
| 27 |
+
def __init__(self, cfg: Config, model, datasets, job_id):
|
| 28 |
+
self.config = cfg
|
| 29 |
+
|
| 30 |
+
# log
|
| 31 |
+
device = "cuda:0"
|
| 32 |
+
if is_main_process():
|
| 33 |
+
if self.config.run.wandb_enabled:
|
| 34 |
+
wandb.init(project="earthlm", config=self.config.model_dump())
|
| 35 |
+
else:
|
| 36 |
+
wandb.init(mode="disabled")
|
| 37 |
+
|
| 38 |
+
if "LOCAL_RANK" in os.environ:
|
| 39 |
+
device = int(os.environ["LOCAL_RANK"])
|
| 40 |
+
else:
|
| 41 |
+
device = self.config.run.device
|
| 42 |
+
print(f"device is {device} could have been {self.config.run.device}")
|
| 43 |
+
self.output_dir = Path(self.config.run.output_dir) / job_id
|
| 44 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 45 |
+
self.log_writter = SummaryWriter(self.output_dir)
|
| 46 |
+
|
| 47 |
+
# settings
|
| 48 |
+
self.device = torch.device(device)
|
| 49 |
+
self.use_distributed = self.config.run.use_distributed
|
| 50 |
+
self.start_epoch = 0
|
| 51 |
+
self.max_epoch = self.config.run.optims.max_epoch
|
| 52 |
+
self.evaluate_only = self.config.run.evaluate
|
| 53 |
+
self.cuda_enabled = self.device.type == "cuda"
|
| 54 |
+
|
| 55 |
+
# test prompt
|
| 56 |
+
self.prompt_template = self.config.model.prompt_template
|
| 57 |
+
|
| 58 |
+
# model
|
| 59 |
+
self._model = model
|
| 60 |
+
self._model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self._model)
|
| 61 |
+
self._model.to(self.device)
|
| 62 |
+
if self.use_distributed:
|
| 63 |
+
self.model = DDP(
|
| 64 |
+
self._model,
|
| 65 |
+
find_unused_parameters=True,
|
| 66 |
+
static_graph=False,
|
| 67 |
+
device_ids=[self.device],
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
self.model = self._model
|
| 71 |
+
|
| 72 |
+
# dataloaders
|
| 73 |
+
self.train_loader = get_dataloader(
|
| 74 |
+
datasets["train"],
|
| 75 |
+
self.config.run,
|
| 76 |
+
is_train=True,
|
| 77 |
+
use_distributed=self.use_distributed,
|
| 78 |
+
)
|
| 79 |
+
self.valid_loader = get_dataloader(
|
| 80 |
+
datasets["valid"],
|
| 81 |
+
self.config.run,
|
| 82 |
+
is_train=False,
|
| 83 |
+
use_distributed=self.use_distributed,
|
| 84 |
+
)
|
| 85 |
+
self.test_loader = get_dataloader(
|
| 86 |
+
datasets["test"],
|
| 87 |
+
self.config.run,
|
| 88 |
+
is_train=False,
|
| 89 |
+
use_distributed=self.use_distributed,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# scaler
|
| 93 |
+
self.use_amp = self.config.run.amp
|
| 94 |
+
if self.use_amp:
|
| 95 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
| 96 |
+
else:
|
| 97 |
+
self.scaler = None
|
| 98 |
+
|
| 99 |
+
# optimizer & scheduler
|
| 100 |
+
self.iters_per_epoch = (
|
| 101 |
+
len(self.train_loader) if self.config.run.epoch_based else self.config.run.iters_per_epoch
|
| 102 |
+
)
|
| 103 |
+
self.optimizer = get_optimizer(self.model, self.config.run.optims)
|
| 104 |
+
self.scheduler = LinearWarmupCosineLRScheduler(
|
| 105 |
+
self.optimizer,
|
| 106 |
+
max_epoch=self.max_epoch,
|
| 107 |
+
iters_per_epoch=self.iters_per_epoch,
|
| 108 |
+
min_lr=self.config.run.optims.min_lr,
|
| 109 |
+
init_lr=self.config.run.optims.init_lr,
|
| 110 |
+
warmup_steps=self.config.run.optims.warmup_steps,
|
| 111 |
+
warmup_start_lr=self.config.run.optims.warmup_start_lr,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
#### augmentations
|
| 115 |
+
# self.rng = random.Random(self.config.run.seed)
|
| 116 |
+
# self.rngnp = np.random.default_rng(seed=self.config.run.seed)
|
| 117 |
+
# self.rngth = torch.Generator(device=args.device)
|
| 118 |
+
# self.rngth.manual_seed(self.config.run.seed)
|
| 119 |
+
# augments = []
|
| 120 |
+
# if self.config.run.augmentations.flip:
|
| 121 |
+
# augments.append(augmentations.Flip(self.config.run.augmentations.flip, rngth=self.rngth, seed=self.config.run.seed))
|
| 122 |
+
# if self.config.run.augmentations.bandmask:
|
| 123 |
+
# augments.append(augmentations.BandMask(self.config.run.augmentations.bandmask, sample_rate=args.sample_rate, rng=self.rng, seed=self.config.run.seed))
|
| 124 |
+
# if self.config.run.augmentations.revecho:
|
| 125 |
+
# augments.append(
|
| 126 |
+
# augmentations.RevEcho(proba=self.config.run.augmentations.revecho,rng=self.rng,seed=self.config.run.seed))
|
| 127 |
+
# self.augment = torch.nn.Sequential(*augments)
|
| 128 |
+
|
| 129 |
+
self.log_config()
|
| 130 |
+
|
| 131 |
+
def unwrap_dist_model(self, model):
|
| 132 |
+
if self.use_distributed:
|
| 133 |
+
return model.module
|
| 134 |
+
else:
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
def train_epoch(self, epoch):
|
| 138 |
+
self.model.train()
|
| 139 |
+
|
| 140 |
+
metric_logger = MetricLogger(delimiter=" ")
|
| 141 |
+
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
| 142 |
+
metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
|
| 143 |
+
|
| 144 |
+
logging.info("Start training epoch {}, {} iters per inner epoch.".format(epoch, self.iters_per_epoch))
|
| 145 |
+
header = "Train: data epoch: [{}]".format(epoch)
|
| 146 |
+
|
| 147 |
+
# Get gradient clipping parameters from config
|
| 148 |
+
clip_grad_norm = self.config.run.optims.max_grad_norm
|
| 149 |
+
clip_grad_value = self.config.run.optims.max_grad_value
|
| 150 |
+
|
| 151 |
+
for i in metric_logger.log_every(
|
| 152 |
+
range(self.iters_per_epoch),
|
| 153 |
+
self.config.run.log_freq,
|
| 154 |
+
header=header,
|
| 155 |
+
logger=self.log_writter,
|
| 156 |
+
start_step=epoch * self.iters_per_epoch,
|
| 157 |
+
):
|
| 158 |
+
if i >= self.iters_per_epoch:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
samples = next(self.train_loader)
|
| 162 |
+
|
| 163 |
+
samples = prepare_sample_dist(samples, self.device)
|
| 164 |
+
|
| 165 |
+
#### augmentation
|
| 166 |
+
# if False:
|
| 167 |
+
# samples = self.augment(samples)
|
| 168 |
+
|
| 169 |
+
self.scheduler.step(cur_epoch=epoch, cur_step=i)
|
| 170 |
+
|
| 171 |
+
with torch.autocast(self.device.type, enabled=self.use_amp, dtype=torch.bfloat16):
|
| 172 |
+
loss = self.model(samples)["loss"]
|
| 173 |
+
if torch.isnan(loss):
|
| 174 |
+
print("loss nan", samples)
|
| 175 |
+
# continue
|
| 176 |
+
|
| 177 |
+
if self.use_amp and self.scaler:
|
| 178 |
+
self.scaler.scale(loss).backward()
|
| 179 |
+
else:
|
| 180 |
+
loss.backward()
|
| 181 |
+
|
| 182 |
+
# Apply gradient clipping
|
| 183 |
+
if clip_grad_norm is not None:
|
| 184 |
+
if self.use_amp and self.scaler:
|
| 185 |
+
self.scaler.unscale_(self.optimizer)
|
| 186 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)
|
| 187 |
+
if clip_grad_value is not None:
|
| 188 |
+
if self.use_amp and self.scaler:
|
| 189 |
+
self.scaler.unscale_(self.optimizer)
|
| 190 |
+
torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=clip_grad_value)
|
| 191 |
+
|
| 192 |
+
if (i + 1) % self.config.run.accum_grad_iters == 0:
|
| 193 |
+
if self.use_amp and self.scaler:
|
| 194 |
+
self.scaler.step(self.optimizer)
|
| 195 |
+
self.scaler.update()
|
| 196 |
+
else:
|
| 197 |
+
self.optimizer.step()
|
| 198 |
+
self.optimizer.zero_grad()
|
| 199 |
+
|
| 200 |
+
metric_logger.update(loss=loss.item())
|
| 201 |
+
metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
|
| 202 |
+
|
| 203 |
+
metric_logger.synchronize_between_processes()
|
| 204 |
+
logging.info("Averaged stats: " + str(metric_logger.global_avg()))
|
| 205 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
| 206 |
+
|
| 207 |
+
@torch.no_grad()
|
| 208 |
+
def valid_epoch(self, epoch, split, decode=True, save_json=False, decode_ratio=1.0):
|
| 209 |
+
"""
|
| 210 |
+
decode=True triggers computation of the custom text-based metrics.
|
| 211 |
+
decode_ratio controls the fraction of batches for which these metrics are computed,
|
| 212 |
+
a speed trade-off due to the cost of the 'generate' method.
|
| 213 |
+
"""
|
| 214 |
+
model = self.unwrap_dist_model(self.model)
|
| 215 |
+
model.eval()
|
| 216 |
+
|
| 217 |
+
dataloader = getattr(self, split + "_loader", None)
|
| 218 |
+
assert dataloader is not None, f"{split}_loader does not exist."
|
| 219 |
+
|
| 220 |
+
metric_logger = MetricLogger(delimiter=" ")
|
| 221 |
+
header = f"Eval: data epoch: [{epoch}]"
|
| 222 |
+
|
| 223 |
+
results_per_task = defaultdict(list) # Store results per task
|
| 224 |
+
overall_results = [] # Store all results for overall metrics
|
| 225 |
+
|
| 226 |
+
# Calculate N based on decode_ratio
|
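# Illustration of the mapping computed below:
#   decode_ratio=1.0  -> N=1   (run `generate` on every batch)
#   decode_ratio=0.25 -> N=4   (run `generate` on batches 0, 4, 8, ...)
#   decode_ratio=0.0  -> N=inf (never run `generate`)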
| 227 |
+
if decode_ratio <= 0.0:
|
| 228 |
+
N = float("inf") # Effectively never run generate
|
| 229 |
+
elif decode_ratio >= 1.0:
|
| 230 |
+
N = 1 # Run generate every batch
|
| 231 |
+
else:
|
| 232 |
+
N = max(int(1 / decode_ratio), 1) # Ensure N is at least 1
|
| 233 |
+
|
| 234 |
+
batch_idx = 0
|
| 235 |
+
|
| 236 |
+
# Initialize overall metrics
|
| 237 |
+
overall_res = {
|
| 238 |
+
"loss": torch.tensor(0.0, device=self.device),
|
| 239 |
+
"correct": torch.tensor(0.0, device=self.device),
|
| 240 |
+
"total": torch.tensor(0.0, device=self.device),
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
# Initialize per-task metrics
|
| 244 |
+
per_task_res = defaultdict(
|
| 245 |
+
lambda: {
|
| 246 |
+
"loss": torch.tensor(0.0, device=self.device),
|
| 247 |
+
"correct": torch.tensor(0.0, device=self.device),
|
| 248 |
+
"total": torch.tensor(0.0, device=self.device),
|
| 249 |
+
"n_sample": 0,
|
| 250 |
+
"predicted_texts": [],
|
| 251 |
+
"gold_texts": [],
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
for samples in metric_logger.log_every(dataloader, self.config.run.log_freq, header=header):
|
| 256 |
+
samples = prepare_sample_dist(samples, self.device)
|
| 257 |
+
|
| 258 |
+
with torch.autocast(self.device.type, enabled=self.use_amp):
|
| 259 |
+
forward_result = model(samples, verbose=True)
|
| 260 |
+
|
| 261 |
+
# Extract batch-level loss and correct counts
|
| 262 |
+
batch_loss = forward_result.get("loss", torch.tensor(0.0, device=self.device))
|
| 263 |
+
batch_correct = forward_result.get("correct", torch.tensor(0.0, device=self.device))
|
| 264 |
+
batch_total = forward_result.get("total", torch.tensor(1.0, device=self.device))
|
| 265 |
+
|
| 266 |
+
batch_size = len(samples["id"])
|
| 267 |
+
|
| 268 |
+
# Update overall metrics with batch-level values
|
| 269 |
+
overall_res["loss"] += batch_loss.detach()
|
| 270 |
+
overall_res["correct"] += batch_correct.detach()
|
| 271 |
+
overall_res["total"] += batch_total.detach()
|
| 272 |
+
|
| 273 |
+
# Decide whether to run generate based on decode_ratio
|
| 274 |
+
if decode and (batch_idx % N == 0):
|
| 275 |
+
prompts = samples.get("prompt", None)
|
| 276 |
+
try:
|
| 277 |
+
generated_texts = model.generate(samples, self.config.generate, prompts=prompts)
|
| 278 |
+
except Exception as e:
|
| 279 |
+
print("error in generation", e)
|
| 280 |
+
generated_texts = [None] * batch_size
|
| 281 |
+
else:
|
| 282 |
+
generated_texts = [None] * batch_size # Placeholder if not decoding
|
| 283 |
+
|
| 284 |
+
# Process per-sample data for per-task metrics and result saving
|
| 285 |
+
for i in range(batch_size):
|
| 286 |
+
task = samples["task"][i]
|
| 287 |
+
|
| 288 |
+
# Collect per-task batch-level metrics
|
| 289 |
+
per_task_res[task]["loss"] += batch_loss.detach()
|
| 290 |
+
per_task_res[task]["correct"] += batch_correct.detach()
|
| 291 |
+
per_task_res[task]["total"] += batch_total.detach()
|
| 292 |
+
per_task_res[task]["n_sample"] += 1
|
| 293 |
+
|
| 294 |
+
res = {
|
| 295 |
+
"id": samples["id"][i],
|
| 296 |
+
"ground_truth": samples["text"][i], # Gold label from dataloader
|
| 297 |
+
"task": task,
|
| 298 |
+
"predicted_text": generated_texts[i],
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
if decode and generated_texts[i] is not None:
|
| 302 |
+
res["prompt"] = samples.get("prompt", [None])[i]
|
| 303 |
+
|
| 304 |
+
results_per_task[task].append(res)
|
| 305 |
+
overall_results.append(res)
|
| 306 |
+
|
| 307 |
+
# Collect texts for custom metrics
|
| 308 |
+
if generated_texts[i] is not None:
|
| 309 |
+
per_task_res[task]["predicted_texts"].append(generated_texts[i])
|
| 310 |
+
per_task_res[task]["gold_texts"].append(samples["text"][i])
|
| 311 |
+
|
| 312 |
+
batch_idx += 1 # Increment batch index
|
| 313 |
+
|
| 314 |
+
if save_json:
|
| 315 |
+
for task, task_results in results_per_task.items():
|
| 316 |
+
self.save_result(task_results, self.output_dir, f"eval_{split}_{task}_epoch_{epoch}")
|
| 317 |
+
# Optionally save overall results
|
| 318 |
+
self.save_result(overall_results, self.output_dir, f"eval_{split}_epoch_{epoch}")
|
| 319 |
+
|
| 320 |
+
# Synchronize metrics across processes if in distributed mode
|
| 321 |
+
if is_dist_avail_and_initialized():
|
| 322 |
+
for key in overall_res:
|
| 323 |
+
dist.all_reduce(overall_res[key])
|
| 324 |
+
|
| 325 |
+
overall_ret = {
|
| 326 |
+
"loss": (overall_res["loss"] / batch_idx).item(),
|
| 327 |
+
"agg_metrics": (overall_res["correct"] / overall_res["total"]).item(),
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if is_main_process():
|
| 331 |
+
# Log overall metrics
|
| 332 |
+
wandb.log(
|
| 333 |
+
{
|
| 334 |
+
f"{split}_loss": overall_ret["loss"],
|
| 335 |
+
f"{split}_accuracy": overall_ret["agg_metrics"],
|
| 336 |
+
"epoch": epoch,
|
| 337 |
+
}
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Compute and log per-task metrics
|
| 341 |
+
for task, res in per_task_res.items():
|
| 342 |
+
if "caption-none" in task:
|
| 343 |
+
continue
|
| 344 |
+
|
| 345 |
+
if self.use_distributed:
|
| 346 |
+
print(f"Rank {dist.get_rank()}, task={task}, ")
|
| 347 |
+
|
| 348 |
+
print(
|
| 349 |
+
f"loss={res['loss'].shape, res['loss'].dtype}, "
|
| 350 |
+
f"correct={res['correct'].shape, res['correct'].dtype}, "
|
| 351 |
+
f"total={res['total'].shape, res['total'].dtype}, "
|
| 352 |
+
f"n_sample={res['n_sample']}"
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Synchronize metrics across processes if in distributed mode
|
| 356 |
+
if is_dist_avail_and_initialized():
|
| 357 |
+
dist.all_reduce(res["loss"])
|
| 358 |
+
dist.all_reduce(res["correct"])
|
| 359 |
+
dist.all_reduce(res["total"])
|
| 360 |
+
n_sample_t = torch.tensor(res["n_sample"], device=self.device)
dist.all_reduce(n_sample_t)
res["n_sample"] = n_sample_t.item()
|
| 361 |
+
|
| 362 |
+
ret = {
|
| 363 |
+
"loss": (res["loss"] / res["n_sample"]).item(),
|
| 364 |
+
"agg_metrics": (res["correct"] / res["total"]).item(),
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
if is_main_process():
|
| 368 |
+
# Log per-task metrics
|
| 369 |
+
wandb.log(
|
| 370 |
+
{
|
| 371 |
+
f"{split}_{task}_loss": ret["loss"],
|
| 372 |
+
f"{split}_{task}_accuracy": ret["agg_metrics"],
|
| 373 |
+
"epoch": epoch,
|
| 374 |
+
}
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# Get and compute custom metrics for this task
|
| 378 |
+
metrics_list = get_task_metrics(task)
|
| 379 |
+
predicted_texts = res["predicted_texts"]
|
| 380 |
+
gold_texts = res["gold_texts"]
|
| 381 |
+
for metric in metrics_list:
|
| 382 |
+
if predicted_texts and gold_texts:
|
| 383 |
+
metric_value = metric.compute_metric(predicted_texts, gold_texts)
|
| 384 |
+
metric_name = metric.__class__.__name__
|
| 385 |
+
wandb.log(
|
| 386 |
+
{
|
| 387 |
+
f"{split}_{task}_{metric_name}": metric_value,
|
| 388 |
+
"epoch": epoch,
|
| 389 |
+
}
|
| 390 |
+
)
|
| 391 |
+
return overall_ret # Return overall metrics
|
| 392 |
+
|
| 393 |
+
def save_result(self, result, result_dir, filename):
|
| 394 |
+
result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, get_rank()))
|
| 395 |
+
final_result_file = os.path.join(result_dir, "%s.json" % filename)
|
| 396 |
+
|
| 397 |
+
try:
|
| 398 |
+
json.dump(result, open(result_file, "w"), ensure_ascii=False)
|
| 399 |
+
except Exception as e:
|
| 400 |
+
logging.warning(f"Error saving {result_file}. Error: {e}")
|
| 401 |
+
json.dump(result, open(result_file, "w", encoding="utf-8"), ensure_ascii=False)
|
| 402 |
+
|
| 403 |
+
# if is_dist_avail_and_initialized():
|
| 404 |
+
# dist.barrier()
|
| 405 |
+
|
| 406 |
+
if is_main_process():
|
| 407 |
+
logging.info("rank %d starts merging results." % get_rank())
|
| 408 |
+
result = []
|
| 409 |
+
|
| 410 |
+
for rank in range(get_world_size()):
|
| 411 |
+
result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
|
| 412 |
+
try:
|
| 413 |
+
res = json.load(open(result_file, "r"))
|
| 414 |
+
except Exception as e:
|
| 415 |
+
logging.warning(f"Error reading {result_file}. Error: {e}")
|
| 416 |
+
res = json.load(open(result_file, "r", encoding="utf-8"))
|
| 417 |
+
result += res
|
| 418 |
+
|
| 419 |
+
try:
|
| 420 |
+
json.dump(result, open(final_result_file, "w"), ensure_ascii=False)
|
| 421 |
+
except Exception as e:
|
| 422 |
+
logging.warning(f"Error saving {final_result_file}. Error: {e}")
|
| 423 |
+
json.dump(
|
| 424 |
+
result,
|
| 425 |
+
open(final_result_file, "w", encoding="utf-8"),
|
| 426 |
+
ensure_ascii=False,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
print("result file saved to %s" % final_result_file)
|
| 430 |
+
|
| 431 |
+
def train(self):
|
| 432 |
+
start_time = time.time()
|
| 433 |
+
best_agg_metric = 0
|
| 434 |
+
best_epoch = 0
|
| 435 |
+
|
| 436 |
+
for cur_epoch in range(self.start_epoch, self.max_epoch):
|
| 437 |
+
if self.evaluate_only:
|
| 438 |
+
break
|
| 439 |
+
|
| 440 |
+
# training phase
|
| 441 |
+
logging.info("Training Phase")
|
| 442 |
+
train_stats = self.train_epoch(cur_epoch)
|
| 443 |
+
self.log_stats(train_stats, split_name="train")
|
| 444 |
+
|
| 445 |
+
# validating phase
|
| 446 |
+
logging.info("Validating Phase")
|
| 447 |
+
valid_log = self.valid_epoch(
|
| 448 |
+
cur_epoch,
|
| 449 |
+
"valid",
|
| 450 |
+
decode=self.config.run.custom_metrics,
|
| 451 |
+
save_json=False,
|
| 452 |
+
decode_ratio=self.config.run.decode_ratio,
|
| 453 |
+
)
|
| 454 |
+
if valid_log is not None:
|
| 455 |
+
if is_main_process():
|
| 456 |
+
agg_metrics = valid_log["agg_metrics"]
|
| 457 |
+
if agg_metrics > best_agg_metric:
|
| 458 |
+
best_agg_metric = agg_metrics
|
| 459 |
+
best_epoch = cur_epoch
|
| 460 |
+
self.save_checkpoint(cur_epoch, is_best=True)
|
| 461 |
+
|
| 462 |
+
valid_log.update({"best_epoch": best_epoch})
|
| 463 |
+
self.log_stats(valid_log, split_name="valid")
|
| 464 |
+
self.save_checkpoint(cur_epoch, is_best=False)
|
| 465 |
+
|
| 466 |
+
# if self.use_distributed:
|
| 467 |
+
# dist.barrier()
|
| 468 |
+
|
| 469 |
+
# testing phase
|
| 470 |
+
if self.evaluate_only:
|
| 471 |
+
self.valid_epoch("best", "test", decode=True, save_json=True)
|
| 472 |
+
|
| 473 |
+
total_time = time.time() - start_time
|
| 474 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 475 |
+
logging.info("Training time {}".format(total_time_str))
|
| 476 |
+
|
| 477 |
+
@main_process
|
| 478 |
+
def log_config(self):
|
| 479 |
+
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
|
| 480 |
+
f.write(json.dumps(self.config.model_dump_json(), indent=4) + "\n")
|
| 481 |
+
|
| 482 |
+
@main_process
|
| 483 |
+
def log_stats(self, stats, split_name):
|
| 484 |
+
if isinstance(stats, dict):
|
| 485 |
+
log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
|
| 486 |
+
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
|
| 487 |
+
f.write(json.dumps(log_stats) + "\n")
|
| 488 |
+
elif isinstance(stats, list):
|
| 489 |
+
pass
|
| 490 |
+
|
| 491 |
+
@main_process
|
| 492 |
+
def save_checkpoint(self, cur_epoch, is_best=False):
|
| 493 |
+
"""
|
| 494 |
+
Save the checkpoint at the current epoch.
|
| 495 |
+
"""
|
| 496 |
+
model_no_ddp = self.unwrap_dist_model(self.model)
|
| 497 |
+
param_grad_dic = {k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()}
|
| 498 |
+
state_dict = model_no_ddp.state_dict()
|
| 499 |
+
for k in list(state_dict.keys()):
|
| 500 |
+
if k in param_grad_dic.keys() and not param_grad_dic[k]:
|
| 501 |
+
# delete parameters that do not require gradient
|
| 502 |
+
del state_dict[k]
|
| 503 |
+
save_obj = {
|
| 504 |
+
"model": state_dict,
|
| 505 |
+
"optimizer": self.optimizer.state_dict(),
|
| 506 |
+
"config": dict(self.config),
|
| 507 |
+
"scaler": self.scaler.state_dict() if self.scaler else None,
|
| 508 |
+
"epoch": cur_epoch,
|
| 509 |
+
}
|
| 510 |
+
save_to = os.path.join(
|
| 511 |
+
self.output_dir,
|
| 512 |
+
"checkpoint_{}.pth".format("best" if is_best else cur_epoch),
|
| 513 |
+
)
|
| 514 |
+
logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
|
| 515 |
+
torch.save(save_obj, save_to)
|
NatureLM/storage_utils.py
ADDED
|
@@ -0,0 +1,26 @@
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import cloudpathlib
|
| 7 |
+
from google.cloud.storage.client import Client
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def is_gcs_path(path: Union[str, os.PathLike]) -> bool:
|
| 13 |
+
return str(path).startswith("gs://")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@lru_cache(maxsize=1)
|
| 17 |
+
def _get_client():
|
| 18 |
+
return cloudpathlib.GSClient(storage_client=Client())
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
_gcp_storage_client = _get_client()
|
| 23 |
+
except Exception:
|
| 24 |
+
logger.warning("Failed to initialize GCS client." "Training wont be able to use GSPath or R2Path without a client.")
|
| 25 |
+
_gcp_storage_client = None
|
| 26 |
+
|
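Usage note for the helpers above (illustrative paths; is_gcs_path needs no credentials):

from NatureLM.storage_utils import is_gcs_path

assert is_gcs_path("gs://my-bucket/audio/clip.wav")
assert not is_gcs_path("/local/audio/clip.wav")
# _get_client() is wrapped in lru_cache(maxsize=1), so the GCS client is built at most once per process.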
NatureLM/task_metric_utils.py
ADDED
|
@@ -0,0 +1,283 @@
|
| 1 |
+
# Taken from DCASE 2021 Task 5 evaluation source code
|
| 2 |
+
# https://github.com/c4dm/dcase-few-shot-bioacoustic
|
| 3 |
+
# MIT License
|
| 4 |
+
|
| 5 |
+
import mir_eval
|
| 6 |
+
import numpy as np
|
| 7 |
+
import scipy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fast_intersect(ref, est):
|
| 11 |
+
"""Find all intersections between reference events and estimated events (fast).
|
| 12 |
+
Best-case complexity: O(N log N + M log M) where N=length(ref) and M=length(est)
|
| 13 |
+
Parameters
|
| 14 |
+
----------
|
| 15 |
+
ref: np.ndarray [shape=(2, n)], real-valued
|
| 16 |
+
Array of reference events. Each column is an event.
|
| 17 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 18 |
+
est: np.ndarray [shape=(2, m)], real-valued
|
| 19 |
+
Array of estimated events. Each column is an event.
|
| 20 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 21 |
+
Returns
|
| 22 |
+
-------
|
| 23 |
+
matches: list of sets, length n, integer-valued
|
| 24 |
+
Property: matches[i] contains the set of all indices j such that
|
| 25 |
+
(ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
|
| 26 |
+
"""
|
| 27 |
+
ref_on_argsort = np.argsort(ref[0, :])
|
| 28 |
+
ref_off_argsort = np.argsort(ref[1, :])
|
| 29 |
+
|
| 30 |
+
est_on_argsort = np.argsort(est[0, :])
|
| 31 |
+
est_off_argsort = np.argsort(est[1, :])
|
| 32 |
+
|
| 33 |
+
est_on_maxindex = est.shape[1]
|
| 34 |
+
est_off_minindex = 0
|
| 35 |
+
estref_matches = [set()] * ref.shape[1]
|
| 36 |
+
refest_matches = [set()] * ref.shape[1]
|
| 37 |
+
for ref_id in range(ref.shape[1]):
|
| 38 |
+
ref_onset = ref[0, ref_on_argsort[ref_id]]
|
| 39 |
+
est_off_sorted = est[1, est_off_argsort[est_off_minindex:]]
|
| 40 |
+
search_result = np.searchsorted(est_off_sorted, ref_onset, side="left")
|
| 41 |
+
est_off_minindex += search_result
|
| 42 |
+
refest_match = est_off_argsort[est_off_minindex:]
|
| 43 |
+
refest_matches[ref_on_argsort[ref_id]] = set(refest_match)
|
| 44 |
+
|
| 45 |
+
ref_offset = ref[1, ref_off_argsort[-1 - ref_id]]
|
| 46 |
+
est_on_sorted = est[0, est_on_argsort[: (1 + est_on_maxindex)]]
|
| 47 |
+
search_result = np.searchsorted(est_on_sorted, ref_offset, side="right")
|
| 48 |
+
est_on_maxindex = search_result - 1
|
| 49 |
+
estref_match = est_on_argsort[: (1 + est_on_maxindex)]
|
| 50 |
+
estref_matches[ref_off_argsort[-1 - ref_id]] = set(estref_match)
|
| 51 |
+
|
| 52 |
+
zip_iterator = zip(refest_matches, estref_matches)
|
| 53 |
+
matches = [x.intersection(y) for (x, y) in zip_iterator]
|
| 54 |
+
return matches
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def iou(ref, est, method="fast"):
|
| 58 |
+
"""Compute pairwise "intersection over union" (IOU) metric between reference
|
| 59 |
+
events and estimated events.
|
| 60 |
+
Let us denote by a_i and b_i the onset and offset of reference event i.
|
| 61 |
+
Let us denote by u_j and v_j the onset and offset of estimated event j.
|
| 62 |
+
The IOU between events i and j is defined as
|
| 63 |
+
(min(b_i, v_j)-max(a_i, u_j)) / (max(b_i, v_j)-min(a_i, u_j))
|
| 64 |
+
if the events are non-disjoint, and equal to zero otherwise.
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
ref: np.ndarray [shape=(2, n)], real-valued
|
| 68 |
+
Array of reference events. Each column is an event.
|
| 69 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 70 |
+
est: np.ndarray [shape=(2, m)], real-valued
|
| 71 |
+
Array of estimated events. Each column is an event.
|
| 72 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 73 |
+
method: str, optional.
|
| 74 |
+
If "fast" (default), computes pairwise intersections via a custom
|
| 75 |
+
dynamic programming algorithm, see fast_intersect.
|
| 76 |
+
If "slow", computes pairwise intersections via bruteforce quadratic
|
| 77 |
+
search, see slow_intersect.
|
| 78 |
+
Returns
|
| 79 |
+
-------
|
| 80 |
+
S: scipy.sparse.dok.dok_matrix, real-valued
|
| 81 |
+
Sparse 2-D matrix. S[i,j] contains the IOU between ref[i] and est[j]
|
| 82 |
+
if these events are non-disjoint and zero otherwise.
|
| 83 |
+
"""
|
| 84 |
+
n_refs = ref.shape[1]
|
| 85 |
+
n_ests = est.shape[1]
|
| 86 |
+
S = scipy.sparse.dok_matrix((n_refs, n_ests))
|
| 87 |
+
|
| 88 |
+
if method == "fast":
|
| 89 |
+
matches = fast_intersect(ref, est)
|
| 90 |
+
elif method == "slow":
|
| 91 |
+
matches = slow_intersect(ref, est)
|
| 92 |
+
|
| 93 |
+
for ref_id in range(n_refs):
|
| 94 |
+
matching_ests = matches[ref_id]
|
| 95 |
+
ref_on = ref[0, ref_id]
|
| 96 |
+
ref_off = ref[1, ref_id]
|
| 97 |
+
|
| 98 |
+
for matching_est_id in matching_ests:
|
| 99 |
+
est_on = est[0, matching_est_id]
|
| 100 |
+
est_off = est[1, matching_est_id]
|
| 101 |
+
intersection = min(ref_off, est_off) - max(ref_on, est_on)
|
| 102 |
+
union = max(ref_off, est_off) - min(ref_on, est_on)
|
| 103 |
+
intersection_over_union = intersection / union
|
| 104 |
+
S[ref_id, matching_est_id] = intersection_over_union
|
| 105 |
+
|
| 106 |
+
return S
|
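Worked example for the IOU formula above: for a reference event (0.0, 2.0) and an estimated event (1.0, 3.0), the intersection is min(2.0, 3.0) - max(0.0, 1.0) = 1.0 and the union is max(2.0, 3.0) - min(0.0, 1.0) = 3.0, so the IOU is 1/3.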
| 107 |
+
|
| 108 |
+
def compute_intersection(ref, est, method="fast"):
|
| 109 |
+
"""Compute pairwise intersection between reference
|
| 110 |
+
events and estimated events.
|
| 111 |
+
Let us denote by a_i and b_i the onset and offset of reference event i.
|
| 112 |
+
Let us denote by u_j and v_j the onset and offset of estimated event j.
|
| 113 |
+
The Intersection between events i and j is defined as
|
| 114 |
+
(min(b_i, v_j)-max(a_i, u_j))
|
| 115 |
+
if the events are non-disjoint, and equal to zero otherwise.
|
| 116 |
+
Parameters
|
| 117 |
+
----------
|
| 118 |
+
ref: np.ndarray [shape=(2, n)], real-valued
|
| 119 |
+
Array of reference events. Each column is an event.
|
| 120 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 121 |
+
est: np.ndarray [shape=(2, m)], real-valued
|
| 122 |
+
Array of estimated events. Each column is an event.
|
| 123 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 124 |
+
method: str, optional.
|
| 125 |
+
If "fast" (default), computes pairwise intersections via a custom
|
| 126 |
+
dynamic programming algorithm, see fast_intersect.
|
| 127 |
+
If "slow", computes pairwise intersections via bruteforce quadratic
|
| 128 |
+
search, see slow_intersect.
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
S: scipy.sparse.dok.dok_matrix, real-valued
|
| 132 |
+
Sparse 2-D matrix. S[i,j] contains the Intersection between ref[i] and est[j]
|
| 133 |
+
if these events are non-disjoint and zero otherwise.
|
| 134 |
+
"""
|
| 135 |
+
n_refs = ref.shape[1]
|
| 136 |
+
n_ests = est.shape[1]
|
| 137 |
+
S = scipy.sparse.dok_matrix((n_refs, n_ests))
|
| 138 |
+
|
| 139 |
+
if method == "fast":
|
| 140 |
+
matches = fast_intersect(ref, est)
|
| 141 |
+
elif method == "slow":
|
| 142 |
+
matches = slow_intersect(ref, est)
|
| 143 |
+
|
| 144 |
+
for ref_id in range(n_refs):
|
| 145 |
+
matching_ests = matches[ref_id]
|
| 146 |
+
ref_on = ref[0, ref_id]
|
| 147 |
+
ref_off = ref[1, ref_id]
|
| 148 |
+
|
| 149 |
+
for matching_est_id in matching_ests:
|
| 150 |
+
est_on = est[0, matching_est_id]
|
| 151 |
+
est_off = est[1, matching_est_id]
|
| 152 |
+
intersection = min(ref_off, est_off) - max(ref_on, est_on)
|
| 153 |
+
# union = max(ref_off, est_off) - min(ref_on, est_on)
|
| 154 |
+
# intersection_over_union = intersection / union
|
| 155 |
+
S[ref_id, matching_est_id] = intersection #_over_union
|
| 156 |
+
|
| 157 |
+
return S
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def match_events(ref, est, min_iou=0.0, method="fast"):
|
| 161 |
+
"""
|
| 162 |
+
Compute a maximum matching between reference and estimated event times,
|
| 163 |
+
subject to a criterion of minimum intersection-over-union (IOU).
|
| 164 |
+
Given two lists of events ``ref`` (reference) and ``est`` (estimated),
|
| 165 |
+
we seek the largest set of correspondences ``(ref[i], est[j])`` such that
|
| 166 |
+
``iou(ref[i], est[j]) > min_iou``
|
| 167 |
+
and such that each ``ref[i]`` and ``est[j]`` is matched at most once.
|
| 168 |
+
This function is strongly inspired by mir_eval.onset.util.match_events.
|
| 169 |
+
It relies on mir_eval's implementation of the Hopcroft-Karp algorithm from
|
| 170 |
+
maximum bipartite graph matching. However, one important difference is that
|
| 171 |
+
mir_eval's distance function relies purely on onset times, whereas this function
|
| 172 |
+
considers both onset times and offset times to compute the IOU metric between
|
| 173 |
+
reference events and estimated events.
|
| 174 |
+
Parameters
|
| 175 |
+
----------
|
| 176 |
+
ref: np.ndarray [shape=(2, n)], real-valued
|
| 177 |
+
Array of reference events. Each column is an event.
|
| 178 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 179 |
+
est: np.ndarray [shape=(2, m)], real-valued
|
| 180 |
+
Array of estimated events. Each column is an event.
|
| 181 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 182 |
+
min_iou: real number in [0, 1). Default: 0.
|
| 183 |
+
Threshold for minimum amount of intersection over union (IOU) to match
|
| 184 |
+
any two events. See the iou method for implementation details.
|
| 185 |
+
method: str, optional.
|
| 186 |
+
If "fast" (default), computes pairwise intersections via a custom
|
| 187 |
+
dynamic programming algorithm, see fast_intersect.
|
| 188 |
+
If "slow", computes pairwise intersections via bruteforce quadratic
|
| 189 |
+
search, see slow_intersect.
|
| 190 |
+
Returns
|
| 191 |
+
-------
|
| 192 |
+
matching : list of tuples
|
| 193 |
+
Every tuple corresponds to a match between one reference event and
|
| 194 |
+
one estimated event.
|
| 195 |
+
``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
|
| 196 |
+
Note that all values i and j appear at most once in the list.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
# Intersect reference events and estimated events
|
| 200 |
+
S = iou(ref, est, method=method)
|
| 201 |
+
|
| 202 |
+
# Threshold intersection-over-union (IOU) ratio
|
| 203 |
+
S_bool = scipy.sparse.dok_matrix(S > min_iou)
|
| 204 |
+
hits = S_bool.keys()
|
| 205 |
+
|
| 206 |
+
# Construct the bipartite graph
|
| 207 |
+
G = {}
|
| 208 |
+
for ref_i, est_i in hits:
|
| 209 |
+
if est_i not in G:
|
| 210 |
+
G[est_i] = []
|
| 211 |
+
G[est_i].append(ref_i)
|
| 212 |
+
|
| 213 |
+
# Apply Hopcroft-Karp algorithm (from mir_eval package)
|
| 214 |
+
# to obtain maximum bipartite graph matching
|
| 215 |
+
matching = sorted(mir_eval.util._bipartite_match(G).items())
|
| 216 |
+
return matching
|
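A small self-contained check of match_events (event times chosen here purely for illustration):

import numpy as np
from NatureLM.task_metric_utils import match_events

# columns are events; row 0 holds onsets, row 1 holds offsets
ref = np.array([[0.0, 5.0], [2.0, 7.0]])    # reference events (0.0, 2.0) and (5.0, 7.0)
est = np.array([[0.1, 10.0], [1.9, 12.0]])  # estimated events (0.1, 1.9) and (10.0, 12.0)
print(match_events(ref, est, min_iou=0.5))  # [(0, 0)]: only the first pair overlaps with IOU above 0.5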
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def slow_intersect(ref, est):
|
| 220 |
+
"""Find all intersections between reference events and estimated events (slow).
|
| 221 |
+
Best-case complexity: O(N*M) where N=ref.shape[1] and M=est.shape[1]
|
| 222 |
+
Parameters
|
| 223 |
+
----------
|
| 224 |
+
ref: np.ndarray [shape=(2, n)], real-valued
|
| 225 |
+
Array of reference events. Each column is an event.
|
| 226 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 227 |
+
est: np.ndarray [shape=(2, m)], real-valued
|
| 228 |
+
Array of estimated events. Each column is an event.
|
| 229 |
+
The first row denotes onset times and the second row denotes offset times.
|
| 230 |
+
Returns
|
| 231 |
+
-------
|
| 232 |
+
matches: list of sets, length n, integer-valued
|
| 233 |
+
Property: matches[i] contains the set of all indices j such that
|
| 234 |
+
(ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
|
| 235 |
+
"""
|
| 236 |
+
matches = []
|
| 237 |
+
for i in range(ref.shape[1]):
|
| 238 |
+
matches.append(
|
| 239 |
+
set(
|
| 240 |
+
[
|
| 241 |
+
j
|
| 242 |
+
for j in range(est.shape[1])
|
| 243 |
+
if ((ref[0, i] <= est[1, j]) and (ref[1, i] >= est[0, j]))
|
| 244 |
+
]
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
return matches
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def frames_to_st_dict(x, sr=16000):
|
| 251 |
+
# x : Tensor of shape (batch, time) or (time,). Entries are 2 (POS), 1 (UNK), and 0 (NEG).
|
| 252 |
+
# returns a list of dicts {"Begin Time (s)" : [...], "End Time (s)" : [...], "Annotation" : [...]} if batch dim exists, or a single dict
|
| 253 |
+
|
| 254 |
+
if len(x.size()) == 2:
|
| 255 |
+
outs = []
|
| 256 |
+
for i in range(x.size(0)):
|
| 257 |
+
x_sub = x[i,:]
|
| 258 |
+
outs.append(_frames_to_st_dict_single(x_sub, sr=sr))
|
| 259 |
+
return outs
|
| 260 |
+
else:
|
| 261 |
+
return _frames_to_st_dict_single(x, sr=sr)
|
| 262 |
+
|
| 263 |
+
def _frames_to_st_dict_single(x, sr=16000):
|
| 264 |
+
d = {"Begin Time (s)" : [], "End Time (s)" : [], "Annotation" : []}
|
| 265 |
+
|
| 266 |
+
for label_i in [1,2]:
|
| 267 |
+
|
| 268 |
+
labels = x.numpy() == label_i # POS : 2, UNK : 1, NEG : 0
|
| 269 |
+
|
| 270 |
+
starts = np.where((~labels[:-1]) & (labels[1:]))[0] + 1
|
| 271 |
+
if labels[0]:
|
| 272 |
+
starts = np.insert(starts, 0, 0)
|
| 273 |
+
|
| 274 |
+
ends = np.where((labels[:-1]) & (~labels[1:]))[0] + 1
|
| 275 |
+
if labels[-1]:
|
| 276 |
+
ends = np.append(ends, len(labels))
|
| 277 |
+
|
| 278 |
+
for start, end in zip(starts, ends):
|
| 279 |
+
d["Begin Time (s)"].append(start/sr)
|
| 280 |
+
d["End Time (s)"].append(end/sr)
|
| 281 |
+
d["Annotation"].append("POS" if label_i == 2 else "UNK")
|
| 282 |
+
|
| 283 |
+
return d
|
NatureLM/task_metrics.py
ADDED
|
@@ -0,0 +1,128 @@
|
| 1 |
+
import re
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from NatureLM.task_metric_utils import match_events
|
| 8 |
+
|
| 9 |
+
# Assume the following functions are imported from the reference implementations:
|
| 10 |
+
# - match_events
|
| 11 |
+
# - iou
|
| 12 |
+
# - fast_intersect
|
| 13 |
+
# - slow_intersect
|
| 14 |
+
# - compute_intersection
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Metric(ABC):
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ExactAccuracy(Metric):
|
| 24 |
+
"""Exact-match accuracy metric."""
|
| 25 |
+
|
| 26 |
+
def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
|
| 27 |
+
predicted_texts = [pt.lower().strip() for pt in predicted_texts]
|
| 28 |
+
gold_texts = [gt.lower().strip() for gt in gold_texts]
|
| 29 |
+
correct = sum(p == g for p, g in zip(predicted_texts, gold_texts))
|
| 30 |
+
return correct / len(gold_texts) if gold_texts else 0.0
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FewShot(Metric):
|
| 34 |
+
"""Few-shot learning metric based on event matching using IoU."""
|
| 35 |
+
|
| 36 |
+
def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
|
| 37 |
+
# Initialize counts
|
| 38 |
+
total_TP = 0
|
| 39 |
+
total_FP = 0
|
| 40 |
+
total_FN = 0
|
| 41 |
+
|
| 42 |
+
for pred_text, gold_text in zip(predicted_texts, gold_texts):
|
| 43 |
+
# Extract events from texts
|
| 44 |
+
pred_events = parse_timestamps_from_text(pred_text)
|
| 45 |
+
gold_events = parse_timestamps_from_text(gold_text)
|
| 46 |
+
|
| 47 |
+
# Convert events to numpy arrays for match_events function
|
| 48 |
+
# Each event is (start_time, end_time), need to transpose to shape (2, n)
|
| 49 |
+
pred_array = np.array(pred_events).T if pred_events else np.empty((2, 0))
|
| 50 |
+
gold_array = np.array(gold_events).T if gold_events else np.empty((2, 0))
|
| 51 |
+
|
| 52 |
+
# Use match_events function from the reference implementation
|
| 53 |
+
matches = match_events(gold_array, pred_array, min_iou=0.5, method="fast")
|
| 54 |
+
|
| 55 |
+
TP = len(matches)
|
| 56 |
+
FP = len(pred_events) - TP
|
| 57 |
+
FN = len(gold_events) - TP
|
| 58 |
+
|
| 59 |
+
total_TP += TP
|
| 60 |
+
total_FP += FP
|
| 61 |
+
total_FN += FN
|
| 62 |
+
|
| 63 |
+
# Compute precision, recall, and F1 score
|
| 64 |
+
precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
|
| 65 |
+
recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
|
| 66 |
+
f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 67 |
+
|
| 68 |
+
return f1_score
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class NoneAccuracy(Metric):
|
| 72 |
+
"""Accuracy for cases where 'None' is the correct answer."""
|
| 73 |
+
|
| 74 |
+
def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
|
| 75 |
+
# Normalize texts
|
| 76 |
+
predicted_texts = [pt.lower().strip() for pt in predicted_texts]
|
| 77 |
+
gold_texts = [gt.lower().strip() for gt in gold_texts]
|
| 78 |
+
# Filter indices where gold_text is 'none'
|
| 79 |
+
indices = [i for i, gt in enumerate(gold_texts) if gt == "none"]
|
| 80 |
+
if not indices:
|
| 81 |
+
return 0.0 # No 'None' cases in gold_texts
|
| 82 |
+
correct = sum(predicted_texts[i] == "none" for i in indices)
|
| 83 |
+
return correct / len(indices)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MultipleSpeciesAccuracy(Metric):
|
| 87 |
+
"""Accuracy for cases where the correct answer has at least one comma (multiple species)."""
|
| 88 |
+
|
| 89 |
+
def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
|
| 90 |
+
# Normalize texts
|
| 91 |
+
predicted_texts = [pt.lower().strip() for pt in predicted_texts]
|
| 92 |
+
gold_texts = [gt.lower().strip() for gt in gold_texts]
|
| 93 |
+
# Filter indices where gold_text contains at least one comma
|
| 94 |
+
indices = [i for i, gt in enumerate(gold_texts) if "," in gt]
|
| 95 |
+
if not indices:
|
| 96 |
+
return 0.0 # No multiple-species cases in gold_texts
|
| 97 |
+
correct = sum(predicted_texts[i] == gold_texts[i] for i in indices)
|
| 98 |
+
return correct / len(indices)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_task_metrics(task: str) -> List[Metric]:
|
| 102 |
+
"""Get a list of metric instances appropriate for the given task."""
|
| 103 |
+
all_metrics = []
|
| 104 |
+
metrics_dict = {}
|
| 105 |
+
|
| 106 |
+
if "classification" in task:
|
| 107 |
+
metrics_dict["ExactAccuracy"] = ExactAccuracy()
|
| 108 |
+
if "fewshot" in task:
|
| 109 |
+
metrics_dict["FewShot"] = FewShot()
|
| 110 |
+
if "detection" in task:
|
| 111 |
+
metrics_dict["ExactAccuracy"] = ExactAccuracy() # Ensures no duplicate
|
| 112 |
+
metrics_dict["NoneAccuracy"] = NoneAccuracy()
|
| 113 |
+
metrics_dict["MultipleSpeciesAccuracy"] = MultipleSpeciesAccuracy()
|
| 114 |
+
|
| 115 |
+
all_metrics = list(metrics_dict.values())
|
| 116 |
+
return all_metrics
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def parse_timestamps_from_text(text: str) -> List[Tuple[float, float]]:
|
| 120 |
+
"""
|
| 121 |
+
Function to parse timestamps from text.
|
| 122 |
+
Extracts timestamps in the format "start-end" where start and end are floats.
|
| 123 |
+
"""
|
| 124 |
+
# Regular expression to extract timestamps in the format "start-end"
|
| 125 |
+
pattern = r"(\d+\.\d+)-(\d+\.\d+)"
|
| 126 |
+
matches = re.findall(pattern, text)
|
| 127 |
+
events = [(float(start), float(end)) for start, end in matches]
|
| 128 |
+
return events
|
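Example of the timestamp parsing these detection metrics rely on (text chosen for illustration):

from NatureLM.task_metrics import parse_timestamps_from_text

text = "2.50-4.75: Bird call; 10.00-12.25: Frog call"
print(parse_timestamps_from_text(text))  # [(2.5, 4.75), (10.0, 12.25)]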
NatureLM/utils.py
ADDED
|
@@ -0,0 +1,382 @@
|
| 1 |
+
# Copyright (2024) Earth Species Project
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import time
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Literal
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import resampy
|
| 24 |
+
import soundfile as sf
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import torchaudio
|
| 28 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 29 |
+
|
| 30 |
+
from NatureLM.dist_utils import get_rank, get_world_size
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
TARGET_SAMPLE_RATE = 16_000
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def snr_scale(clean, noise, snr):
|
| 39 |
+
# Ensure both clean and noise have the same length
|
| 40 |
+
assert clean.shape == noise.shape, "Clean and noise must have the same shape."
|
| 41 |
+
|
| 42 |
+
# Compute power (mean squared amplitude)
|
| 43 |
+
power_signal = torch.mean(clean**2)
|
| 44 |
+
power_noise = torch.mean(noise**2)
|
| 45 |
+
|
| 46 |
+
# Prevent division by zero
|
| 47 |
+
epsilon = 1e-10
|
| 48 |
+
power_noise = torch.clamp(power_noise, min=epsilon)
|
| 49 |
+
|
| 50 |
+
# Calculate desired noise power based on SNR
|
| 51 |
+
desired_noise_power = power_signal / (10 ** (snr / 10))
|
| 52 |
+
|
| 53 |
+
# Scale noise to achieve the desired noise power
|
| 54 |
+
scale = torch.sqrt(desired_noise_power / power_noise)
|
| 55 |
+
scaled_noise = scale * noise
|
| 56 |
+
|
| 57 |
+
return scaled_noise
|
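Example use of snr_scale (random tensors for illustration): scale the noise so the mixture has a 10 dB signal-to-noise ratio, then add it to the clean signal. torch is already imported at the top of this module.

clean = torch.randn(16000)
noise = torch.randn(16000)
noisy = clean + snr_scale(clean, noise, snr=10.0)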
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def time_scale(signal, scale=2.0, rngnp=None, seed=42):
|
| 61 |
+
if rngnp is None:
|
| 62 |
+
rngnp = np.random.default_rng(seed=seed)
|
| 63 |
+
scaling = np.power(scale, rngnp.uniform(-1, 1))
|
| 64 |
+
output_size = int(signal.shape[-1] * scaling)
|
| 65 |
+
ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
|
| 66 |
+
ref1 = ref.clone().type(torch.int64)
|
| 67 |
+
ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64))
|
| 68 |
+
r = ref - ref1.type(ref.type())
|
| 69 |
+
scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
|
| 70 |
+
|
| 71 |
+
## trim or zero pad to the original size
|
| 72 |
+
if scaled_signal.shape[-1] > signal.shape[-1]:
|
| 73 |
+
nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
|
| 74 |
+
scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
|
| 75 |
+
else:
|
| 76 |
+
nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
|
| 77 |
+
pad_left = int(np.random.uniform() * nframes_diff)
|
| 78 |
+
pad_right = nframes_diff - pad_left
|
| 79 |
+
scaled_signal = F.pad(input=scaled_signal, pad=(pad_left, pad_right), mode="constant", value=0)
|
| 80 |
+
return scaled_signal
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def mel_frequencies(n_mels, fmin, fmax):
|
| 84 |
+
def _hz_to_mel(f):
|
| 85 |
+
return 2595 * np.log10(1 + f / 700)
|
| 86 |
+
|
| 87 |
+
def _mel_to_hz(m):
|
| 88 |
+
return 700 * (10 ** (m / 2595) - 1)
|
| 89 |
+
|
| 90 |
+
low = _hz_to_mel(fmin)
|
| 91 |
+
high = _hz_to_mel(fmax)
|
| 92 |
+
|
| 93 |
+
mels = np.linspace(low, high, n_mels)
|
| 94 |
+
|
| 95 |
+
return _mel_to_hz(mels)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def now_as_str() -> str:
|
| 99 |
+
return datetime.now().strftime("%Y%m%d%H%M")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_dataloader(dataset, config, is_train=True, use_distributed=True):
|
| 103 |
+
if use_distributed:
|
| 104 |
+
sampler = DistributedSampler(dataset, shuffle=is_train, num_replicas=get_world_size(), rank=get_rank())
|
| 105 |
+
else:
|
| 106 |
+
sampler = None
|
| 107 |
+
|
| 108 |
+
loader = DataLoader(
|
| 109 |
+
dataset,
|
| 110 |
+
batch_size=config.batch_size_train if is_train else config.batch_size_eval,
|
| 111 |
+
num_workers=config.num_workers,
|
| 112 |
+
pin_memory=False,
|
| 113 |
+
sampler=sampler,
|
| 114 |
+
shuffle=sampler is None and is_train,
|
| 115 |
+
collate_fn=dataset.collater,
|
| 116 |
+
drop_last=is_train,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if is_train:
|
| 120 |
+
loader = IterLoader(loader, use_distributed=use_distributed)
|
| 121 |
+
|
| 122 |
+
return loader
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def apply_to_sample(f, sample):
|
| 126 |
+
if len(sample) == 0:
|
| 127 |
+
return {}
|
| 128 |
+
|
| 129 |
+
def _apply(x):
|
| 130 |
+
if torch.is_tensor(x):
|
| 131 |
+
return f(x)
|
| 132 |
+
elif isinstance(x, dict):
|
| 133 |
+
return {key: _apply(value) for key, value in x.items()}
|
| 134 |
+
elif isinstance(x, list):
|
| 135 |
+
return [_apply(x) for x in x]
|
| 136 |
+
else:
|
| 137 |
+
return x
|
| 138 |
+
|
| 139 |
+
return _apply(sample)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def move_to_device(sample, device):
|
| 143 |
+
def _move_to_device(tensor):
|
| 144 |
+
return tensor.to(device)
|
| 145 |
+
|
| 146 |
+
return apply_to_sample(_move_to_device, sample)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def prepare_sample(samples, cuda_enabled=True):
|
| 150 |
+
if cuda_enabled:
|
| 151 |
+
samples = move_to_device(samples, "cuda")
|
| 152 |
+
|
| 153 |
+
# TODO fp16 support
|
| 154 |
+
|
| 155 |
+
return samples
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def prepare_sample_dist(samples, device):
|
| 159 |
+
samples = move_to_device(samples, device)
|
| 160 |
+
|
| 161 |
+
# TODO fp16 support
|
| 162 |
+
|
| 163 |
+
return samples
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class IterLoader:
|
| 167 |
+
"""
|
| 168 |
+
A wrapper to convert DataLoader as an infinite iterator.
|
| 169 |
+
|
| 170 |
+
Modified from:
|
| 171 |
+
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
|
| 175 |
+
self._dataloader = dataloader
|
| 176 |
+
self.iter_loader = iter(self._dataloader)
|
| 177 |
+
self._use_distributed = use_distributed
|
| 178 |
+
self._epoch = 0
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def epoch(self) -> int:
|
| 182 |
+
return self._epoch
|
| 183 |
+
|
| 184 |
+
def __next__(self):
|
| 185 |
+
try:
|
| 186 |
+
data = next(self.iter_loader)
|
| 187 |
+
except StopIteration:
|
| 188 |
+
self._epoch += 1
|
| 189 |
+
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
|
| 190 |
+
self._dataloader.sampler.set_epoch(self._epoch)
|
| 191 |
+
time.sleep(2) # Prevent possible deadlock during epoch transition
|
| 192 |
+
self.iter_loader = iter(self._dataloader)
|
| 193 |
+
data = next(self.iter_loader)
|
| 194 |
+
|
| 195 |
+
return data
|
| 196 |
+
|
| 197 |
+
def __iter__(self):
|
| 198 |
+
return self
|
| 199 |
+
|
| 200 |
+
def __len__(self):
|
| 201 |
+
return len(self._dataloader)
|
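Sketch of how IterLoader is consumed (toy in-memory dataset; get_dataloader above wraps the training loader this way):

loader = IterLoader(DataLoader(list(range(10)), batch_size=4), use_distributed=False)
for _ in range(5):           # can draw more batches than a single epoch contains
    batch = next(loader)     # the wrapper restarts the underlying loader on StopIteration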
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def prepare_one_sample(wav_path: str, wav_processor=None, cuda_enabled=True) -> dict:
|
| 205 |
+
"""Prepare a single sample for inference.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
wav_path: Path to the audio file.
|
| 209 |
+
wav_processor: A function to process the audio file.
|
| 210 |
+
cuda_enabled: Whether to move the sample to the GPU.
|
| 211 |
+
"""
|
| 212 |
+
audio, sr = sf.read(wav_path)
|
| 213 |
+
if len(audio.shape) == 2: # stereo to mono
|
| 214 |
+
audio = audio.mean(axis=1)
|
| 215 |
+
if len(audio) < sr: # pad audio to at least 1s
|
| 216 |
+
sil = np.zeros(sr - len(audio), dtype=float)
|
| 217 |
+
audio = np.concatenate((audio, sil), axis=0)
|
| 218 |
+
audio = audio[: sr * 10] # truncate audio to at most 10s
|
| 219 |
+
|
| 220 |
+
# spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
|
| 221 |
+
print("audio shape", audio.shape)
|
| 222 |
+
|
| 223 |
+
audio_t = torch.tensor(audio).unsqueeze(0)
|
| 224 |
+
audio_t = torchaudio.functional.resample(audio_t, sr, TARGET_SAMPLE_RATE)
|
| 225 |
+
print("audio shape after resample", audio_t.shape)
|
| 226 |
+
|
| 227 |
+
samples = {
|
| 228 |
+
"raw_wav": audio_t,
|
| 229 |
+
"padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
|
| 230 |
+
"audio_chunk_sizes": [1],
|
| 231 |
+
}
|
| 232 |
+
if cuda_enabled:
|
| 233 |
+
samples = move_to_device(samples, "cuda")
|
| 234 |
+
|
| 235 |
+
return samples
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+

def prepare_one_sample_waveform(audio, cuda_enabled=True, sr=16000):
    print("shape", audio.shape)
    if len(audio.shape) == 2:  # stereo to mono
        print("converting stereo to mono")
        audio = audio.mean(axis=1)
    if len(audio) < sr:  # pad audio to at least 1s
        sil = np.zeros(sr - len(audio), dtype=float)
        audio = np.concatenate((audio, sil), axis=0)
    audio = audio[: sr * 10]  # truncate audio to at most 10s

    samples = {
        "raw_wav": torch.tensor(audio).unsqueeze(0).type(torch.DoubleTensor),
        "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
    }
    if cuda_enabled:
        samples = move_to_device(samples, "cuda")

    return samples
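
# Usage sketch (illustrative, not part of the original module): prepare an
# in-memory waveform (e.g. straight from a recorder or a Gradio audio widget)
# instead of a file path; the array is assumed to already be at 16 kHz.
def _example_prepare_one_sample_waveform():
    waveform = np.random.uniform(-0.1, 0.1, 16000 * 3)  # 3 s of noise
    samples = prepare_one_sample_waveform(waveform, cuda_enabled=False, sr=16000)
    return samples["raw_wav"].shape  # (1, 48000)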

def prepare_sample_waveforms(audio_paths, cuda_enabled=True, sr=TARGET_SAMPLE_RATE, max_length_seconds=10):
    batch_len = sr  # minimum length of audio
    audios = []
    for audio_path in audio_paths:
        audio, loaded_sr = sf.read(audio_path)
        if len(audio.shape) == 2:
            audio = audio[:, 0]
        audio = audio[: loaded_sr * max_length_seconds]  # truncate before resampling
        audio = resampy.resample(audio, loaded_sr, sr)
        audio = torch.from_numpy(audio)

        if len(audio) < sr * max_length_seconds:
            pad_size = sr * max_length_seconds - len(audio)
            audio = torch.nn.functional.pad(audio, (0, pad_size))
        audio = torch.clamp(audio, -1.0, 1.0)
        if len(audio) > batch_len:
            batch_len = len(audio)
        audios.append(audio)
    padding_mask = torch.zeros((len(audios), batch_len), dtype=torch.bool)
    for i in range(len(audios)):
        if len(audios[i]) < batch_len:
            orig_len = len(audios[i])
            pad_len = batch_len - orig_len
            sil = torch.zeros(pad_len, dtype=torch.float32)
            audios[i] = torch.cat((audios[i], sil), dim=0)
            padding_mask[i, orig_len:] = True  # mark only the padded tail
    audios = torch.stack(audios, dim=0)

    samples = {
        "raw_wav": audios,
        "padding_mask": padding_mask,
        "audio_chunk_sizes": [len(audio_paths)],
    }
    if cuda_enabled:
        samples = move_to_device(samples, "cuda")

    return samples
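
# Usage sketch (illustrative, not part of the original module): batch two
# synthetic clips of different lengths; the shorter one is zero-padded and the
# padded tail is flagged in padding_mask.
def _example_prepare_sample_waveforms():
    import tempfile

    paths = []
    for seconds in (2, 4):
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        sf.write(tmp.name, np.random.uniform(-0.1, 0.1, 16000 * seconds), 16000)
        paths.append(tmp.name)
    samples = prepare_sample_waveforms(paths, cuda_enabled=False)
    return samples["raw_wav"].shape, samples["padding_mask"].shape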

def generate_sample_batches(
    audio_path,
    cuda_enabled: bool = True,
    sr: int = TARGET_SAMPLE_RATE,
    chunk_len: int = 10,
    hop_len: int = 5,
    batch_size: int = 4,
):
    audio, loaded_sr = sf.read(audio_path)
    if len(audio.shape) == 2:  # stereo to mono
        audio = audio.mean(axis=1)
    audio = torchaudio.functional.resample(torch.from_numpy(audio), loaded_sr, sr)
    hop_len = hop_len * sr
    # Window length in samples: chunk_len seconds, or the whole clip if it is shorter.
    chunk_len = min(len(audio), chunk_len * sr)
    chunks = []

    for i in range(0, len(audio), hop_len):
        chunk = audio[i : i + chunk_len]
        if len(chunk) < chunk_len:
            break
        chunks.append(chunk)

    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        # chunk_len is already expressed in samples at this point.
        padding_mask = torch.zeros((len(batch), chunk_len), dtype=torch.bool)
        batch = torch.stack(batch, dim=0)
        samples = {
            "raw_wav": batch,
            "padding_mask": padding_mask,
            "audio_chunk_sizes": [1 for _ in range(len(batch))],
        }
        if cuda_enabled:
            samples = move_to_device(samples, "cuda")
        yield samples
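
# Usage sketch (illustrative, not part of the original module): slide 10 s
# windows with a 5 s hop over a synthetic 30 s recording and collect the batch
# shapes yielded by the generator.
def _example_generate_sample_batches():
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        sf.write(tmp.name, np.random.uniform(-0.1, 0.1, 16000 * 30), 16000)
    shapes = [s["raw_wav"].shape for s in generate_sample_batches(tmp.name, cuda_enabled=False)]
    return shapes  # e.g. [(4, 160000), (1, 160000)] when TARGET_SAMPLE_RATE is 16 kHz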

def prepare_samples_for_detection(samples, prompt, label):
    prompts = [prompt for _ in range(len(samples["raw_wav"]))]
    labels = [label for _ in range(len(samples["raw_wav"]))]
    task = ["detection" for _ in range(len(samples["raw_wav"]))]
    samples["prompt"] = prompts
    samples["text"] = labels
    samples["task"] = task
    return samples
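
# Usage sketch (illustrative, not part of the original module): attach the same
# detection prompt and label to every clip in a batch. The prompt text below is
# a made-up example, not a prompt shipped with NatureLM.
def _example_prepare_samples_for_detection():
    samples = prepare_one_sample_waveform(np.zeros(16000), cuda_enabled=False)
    samples = prepare_samples_for_detection(
        samples, prompt="Is a bird audible in this clip? Answer yes or no.", label="yes"
    )
    return samples["prompt"], samples["task"]  # (["Is a bird ..."], ["detection"])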

def universal_torch_load(
    f: str | os.PathLike,
    *,
    cache_mode: Literal["none", "use", "force"] = "none",
    **kwargs,
) -> Any:
    """
    Wrapper function for torch.load that can handle GCS paths.

    This function provides a convenient way to load PyTorch objects from both local and
    Google Cloud Storage (GCS) paths. For GCS paths, it can optionally cache the
    downloaded files locally to avoid repeated downloads.

    Note: in this Space build only the local-path branch is implemented, so
    ``cache_mode`` currently has no effect.

    The cache location is determined by:
    1. The ESP_CACHE_HOME environment variable if set
    2. Otherwise defaults to ~/.cache/esp/

    Args:
        f: File-like object, string or PathLike object.
            Can be a local path or a GCS path (starting with 'gs://').
        cache_mode (str, optional): Cache mode for GCS files. Options are:
            "none": No caching (use bucket directly)
            "use": Use cache if available, download if not
            "force": Force redownload even if cache exists
            Defaults to "none".
        **kwargs: Additional keyword arguments passed to torch.load().

    Returns:
        The object loaded from the file using torch.load.

    Raises:
        IsADirectoryError: If the GCS path points to a directory instead of a file.
        FileNotFoundError: If the local file does not exist.
    """
    f = Path(f)
    if not f.exists():
        raise FileNotFoundError(f"File does not exist: {f}")

    with open(f, "rb") as opened_file:
        return torch.load(opened_file, **kwargs)
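
# Usage sketch (illustrative, not part of the original module): round-trip a
# small checkpoint dict through universal_torch_load using a temporary file.
def _example_universal_torch_load():
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
        torch.save({"step": 1, "weights": torch.randn(4)}, tmp.name)
    ckpt = universal_torch_load(tmp.name, map_location="cpu")
    return ckpt["step"]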

README.md
CHANGED
@@ -1,14 +1,34 @@
 ---
-title: NatureLM Audio
-emoji:
-colorFrom:
-colorTo:
+title: NatureLM Audio Demo
+emoji: 🎵
+colorFrom: purple
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.38.2
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description:
+short_description: Audio analysis with NatureLM model
 ---
 
+# NatureLM Audio Demo
+
+This is a demo of the NatureLM audio analysis model. The app provides three main features:
+
+## Features
+
+1. **Chat Interface**: Upload audio files and ask questions about them
+2. **Batch Processing**: Process multiple audio files with the same task
+3. **Long Recording Analysis**: Analyze long audio recordings by chunking them
+
+## Usage
+
+- **First Use**: The model will load automatically when you first use it (this may take a few minutes)
+- **Subsequent Uses**: The model stays loaded for faster responses
+- **Demo Mode**: If the model fails to load, the app will run in demo mode
+
+## Model Loading
+
+The app uses lazy loading to start quickly. The model is only loaded when you first interact with it, not during app initialization. This prevents timeout issues on HuggingFace Spaces.
+
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Space.yaml
ADDED
sdk: gradio
python_version: 3.10
hardware: cpu

configs/inference.yml
ADDED
model:
  llama_path: "meta-llama/Meta-Llama-3.1-8B-Instruct"

  freeze_beats: True
  device: "cuda"
  use_audio_Qformer: True
  max_pooling: False
  downsample_factor: 8
  freeze_audio_QFormer: False
  window_level_Qformer: True
  num_audio_query_token: 1
  second_per_window: 0.333333
  second_stride: 0.333333

  audio_llama_proj_model: ""
  freeze_audio_llama_proj: False

  lora: True
  lora_rank: 32
  lora_alpha: 32
  lora_dropout: 0.1

  prompt_template: "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
  max_txt_len: 160
  end_sym: <|end_of_text|>

  beats_cfg:
    input_patch_size: 16
    embed_dim: 512
    conv_bias: False
    encoder_layers: 12
    encoder_embed_dim: 768
    encoder_ffn_embed_dim: 3072
    encoder_attention_heads: 12
    activation_fn: "gelu"
    layer_wise_gradient_decay_ratio: 0.6
    layer_norm_first: False
    deep_norm: True
    dropout: 0.0
    attention_dropout: 0.0
    activation_dropout: 0.0
    encoder_layerdrop: 0.05
    dropout_input: 0.0
    conv_pos: 128
    conv_pos_groups: 16
    relative_position_embedding: True
    num_buckets: 320
    max_distance: 800
    gru_rel_pos: True
    finetuned_model: True
    predictor_dropout: 0.0
    predictor_class: 527

generate:
  max_new_tokens: 300
  num_beams: 2
  do_sample: False
  min_length: 1
  temperature: 0.1
  repetition_penalty: 1.0
  length_penalty: 1.0
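
The snippet below is a minimal sketch of how this YAML can be consumed, using plain PyYAML rather than the repository's own NatureLM/config.py loader (which presumably validates the same fields with pydantic); the file path and key names simply mirror the config above.

import yaml

with open("configs/inference.yml") as fh:
    cfg = yaml.safe_load(fh)

llama_path = cfg["model"]["llama_path"]   # "meta-llama/Meta-Llama-3.1-8B-Instruct"
generate_kwargs = cfg["generate"]         # max_new_tokens, num_beams, temperature, ...
print(llama_path, generate_kwargs["max_new_tokens"])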

requirements.txt
ADDED
torch>=2.2.2
torchaudio>=2.2.2
torchvision>=0.17.2
transformers[sentencepiece]>=4.44.2
datasets>=2.20.0
cloudpathlib[gs]>=0.20.0
einops>=0.8.0
gradio>=5.10.0
google-cloud-aiplatform>=1.76.0
Levenshtein>=0.25.1
librosa>=0.9.2
memoization>=0.4.0
mir-eval>=0.7
numpy>=1.26.4
pandas>=1.4.3
peft>=0.11.1
plumbum>=1.7.2
pydantic-settings>=2.7.1
pydantic>=2.7.4
pydub>=0.25.1
pyyaml>=6.0
resampy>=0.3.1
scipy>=1.14.0
soundfile>=0.12.1
tensorboard>=2.18.0
tensorboardX>=2.6.2.2
tqdm>=4.66.4
wandb>=0.17.3
click>=8.1.7
git+https://github.com/earthspecies/beans-zero.git