import base64
import re
from dataclasses import dataclass
from itertools import groupby
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchaudio
from huggingface_hub import hf_hub_download
from transformers import (
    Pipeline,
    SeamlessM4TFeatureExtractor,
    Wav2Vec2BertModel,
    Wav2Vec2BertPreTrainedModel,
    Wav2Vec2BertProcessor,
    Wav2Vec2CTCTokenizer,
    pipeline,
)
from transformers.modeling_outputs import ModelOutput
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
    _HIDDEN_STATES_START_POSITION,
)
from transformers.pipelines import PIPELINE_REGISTRY


# Jyutping onsets (initial consonants), used to locate syllable boundaries
# when tones are re-attached to the decoded symbol stream.
ONSETS = {
    "b", "d", "g", "gw", "z",
    "p", "t", "k", "kw", "c",
    "m", "n", "ng",
    "f", "h", "s",
    "l", "w", "j",
}


class SpeechToJyutpingPipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # _sanitize_parameters runs on every pipeline call, so only build the
        # tone tokenizer, processor, and onset-id set once.
        if not hasattr(self, "tone_tokenizer"):
            tone_vocab_file = hf_hub_download(
                repo_id="hon9kon9ize/wav2vec2bert-jyutping",
                filename="tone_vocab.json",
            )
            self.tone_tokenizer = Wav2Vec2CTCTokenizer(
                tone_vocab_file,
                unk_token="[UNK]",
                pad_token="[PAD]",
                word_delimiter_token="|",
            )
            self.processor = Wav2Vec2BertProcessor(
                feature_extractor=self.feature_extractor,
                tokenizer=self.tokenizer,
            )
            self.onset_ids = {
                self.processor.tokenizer.convert_tokens_to_ids(onset)
                for onset in ONSETS
            }
        return {}, {}, {}

    def preprocess(self, inputs):
        # Load the audio and resample to the 16 kHz rate the feature
        # extractor expects.
        waveform, original_sampling_rate = torchaudio.load(inputs)
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sampling_rate, new_freq=16000
        )
        resampled_array = resampler(waveform).numpy().flatten()

        input_features = self.processor(
            resampled_array, sampling_rate=16_000, return_tensors="pt"
        ).input_features
        return {"input_features": input_features.to(self.device)}

    def _forward(self, model_inputs):
        outputs = self.model(
            input_features=model_inputs["input_features"],
        )
        # Only the two logit tensors are consumed by postprocess.
        return {
            "jyutping_logits": outputs.jyutping_logits,
            "tone_logits": outputs.tone_logits,
        }

    def postprocess(self, model_outputs):
        tone_logits = model_outputs["tone_logits"]
        predicted_ids = torch.argmax(model_outputs["jyutping_logits"], dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        symbols = [w for w in transcription.split(" ") if len(w) > 0]

        # Keep (frame_index, token_id) pairs for every non-padding frame, then
        # take the first frame of each run between word delimiters so each
        # decoded symbol is paired with the frame where it starts.
        ids_w_index = [(i, _id.item()) for i, _id in enumerate(predicted_ids[0])]
        ids_w_index = [
            i for i in ids_w_index if i[1] != self.processor.tokenizer.pad_token_id
        ]
        split_ids_index = [
            list(group)[0]
            for k, group in groupby(
                ids_w_index,
                lambda x: x[1] == self.processor.tokenizer.word_delimiter_token_id,
            )
            if not k
        ]

        assert len(split_ids_index) == len(symbols)

        transcription = ""
        last_onset_index = -1
        tone_probs = []

        for cur_ids_w_index, cur_word in zip(split_ids_index, symbols):
            symbol_index, symbol_token_id = cur_ids_w_index
            if symbol_token_id in self.onset_ids:
                # A new onset closes the previous syllable: pool the tone
                # logits over the preceding frames and emit that tone.
                if last_onset_index > -1:
                    tone_prob = torch.zeros(tone_logits.shape[-1]).to(
                        tone_logits.device
                    )
                    for i in range(last_onset_index, symbol_index):
                        tone_prob += tone_logits[0, i, :]
                    # Zero out the special tokens (pad/unk/delimiter) before
                    # picking the tone.
                    tone_prob[[0, 1, 2]] = 0.0
                    tone_probs.append(tone_prob[3:].softmax(dim=-1))
                    predicted_tone_id = torch.argmax(tone_prob.softmax(dim=-1)).item()
                    transcription += (
                        self.tone_tokenizer.decode([predicted_tone_id]) + "_"
                    )
                transcription += "_" + cur_word
                last_onset_index = symbol_index
            else:
                transcription += cur_word
                if symbol_index == len(predicted_ids[0]) - 1:
                    # Final syllable: pool the tone logits over the
                    # remaining frames.
                    tone_prob = torch.zeros(tone_logits.shape[-1]).to(
                        tone_logits.device
                    )
                    for i in range(last_onset_index, len(predicted_ids[0])):
                        tone_prob += tone_logits[0, i, :]
                    tone_prob[[0, 1, 2]] = 0.0
                    tone_probs.append(tone_prob[3:].softmax(dim=-1))
                    predicted_tone_id = torch.argmax(tone_prob.softmax(dim=-1)).item()
                    transcription += self.tone_tokenizer.decode([predicted_tone_id]) + "_"

        # Underscores mark syllable joints; turn them back into spaces.
        transcription = re.sub(
            r"\s+", " ", transcription.replace("_", " ").strip()
        )
        # torch.stack raises on an empty list, so guard the no-tone case.
        tone_probs = torch.stack(tone_probs).cpu().tolist() if tone_probs else []

        return {"transcription": transcription, "tone_probs": tone_probs}


PIPELINE_REGISTRY.register_pipeline(
    "speech-to-jyutping",
    pipeline_class=SpeechToJyutpingPipeline,
)
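
# With the task registered, the pipeline can be built through the stock
# `pipeline()` factory, as `EndpointHandler` below does. A minimal sketch
# (the "audio.wav" path is a placeholder, not a file shipped with the repo):
#
#   pipe = pipeline(
#       task="speech-to-jyutping",
#       model=Wav2Vec2BertForCantonese.from_pretrained(
#           "hon9kon9ize/wav2vec2bert-jyutping"
#       ),
#       feature_extractor=SeamlessM4TFeatureExtractor.from_pretrained(
#           "hon9kon9ize/wav2vec2bert-jyutping"
#       ),
#       tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(
#           "hon9kon9ize/wav2vec2bert-jyutping"
#       ),
#   )
#   print(pipe("audio.wav")["transcription"])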


@dataclass
class JyutpingOutput(ModelOutput):
    """
    Output type of Wav2Vec2BertForCantonese.
    """

    loss: Optional[torch.FloatTensor] = None
    jyutping_logits: torch.FloatTensor = None
    tone_logits: torch.FloatTensor = None
    jyutping_loss: Optional[torch.FloatTensor] = None
    tone_loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
    """
    Wav2Vec2BertModel with two CTC heads on top of the hidden states: a linear
    layer that outputs Jyutping logits and a linear layer that outputs tone
    logits.
    """

    def __init__(
        self,
        config,
        tone_vocab_size: int = 9,
    ):
        super().__init__(config)

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: "
                "`Wav2Vec2BertForCantonese.from_pretrained(..., vocab_size=vocab_size)`, "
                "or define `vocab_size` in your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JyutpingOutput]:
        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(
                    input_features.shape[:2],
                    device=input_features.device,
                    dtype=torch.long,
                )
            )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum(-1)
            ).to(torch.long)

            # Labels padded with negative values are ignored; flatten the
            # rest for CTC.
            jyutping_labels_mask = jyutping_labels >= 0
            jyutping_target_lengths = jyutping_labels_mask.sum(-1)
            jyutping_flattened_targets = jyutping_labels.masked_select(
                jyutping_labels_mask
            )

            # ctc_loss expects (time, batch, vocab) log-probabilities.
            jyutping_log_probs = nn.functional.log_softmax(
                jyutping_logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                jyutping_loss = nn.functional.ctc_loss(
                    jyutping_log_probs,
                    jyutping_flattened_targets,
                    input_lengths,
                    jyutping_target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

            # Same CTC loss for the tone head.
            tone_labels_mask = tone_labels >= 0
            tone_target_lengths = tone_labels_mask.sum(-1)
            tone_flattened_targets = tone_labels.masked_select(tone_labels_mask)

            tone_log_probs = nn.functional.log_softmax(
                tone_logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                tone_loss = nn.functional.ctc_loss(
                    tone_log_probs,
                    tone_flattened_targets,
                    input_lengths,
                    tone_target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JyutpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def inference(
        self,
        processor: Wav2Vec2BertProcessor,
        tone_tokenizer: Wav2Vec2CTCTokenizer,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        outputs = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        jyutping_logits = outputs.jyutping_logits
        tone_logits = outputs.tone_logits
        jyutping_pred_ids = torch.argmax(jyutping_logits, dim=-1)
        tone_pred_ids = torch.argmax(tone_logits, dim=-1)
        jyutping_pred = processor.batch_decode(jyutping_pred_ids)[0]
        tone_pred = tone_tokenizer.batch_decode(tone_pred_ids)[0]
        jyutping_list = jyutping_pred.split(" ")
        tone_list = tone_pred.split(" ")
        jyutping_output = []

        # Mark onsets with a leading underscore and everything else with a
        # trailing one, so joining and re-splitting on underscores groups
        # each onset with its rime into one syllable.
        for jypt in jyutping_list:
            is_initial = jypt in ONSETS

            if is_initial:
                jypt = "_" + jypt
            else:
                jypt = jypt + "_"

            jyutping_output.append(jypt)

        jyutping_output = re.sub(
            r"\s+", " ", "".join(jyutping_output).replace("_", " ").strip()
        ).split(" ")

        # Align the tone sequence with the syllable sequence: truncate extra
        # tones, or pad by repeating the last one.
        if len(tone_list) > len(jyutping_output):
            tone_list = tone_list[: len(jyutping_output)]
        elif len(tone_list) < len(jyutping_output):
            tone_list = tone_list + [tone_list[-1]] * (
                len(jyutping_output) - len(tone_list)
            )

        return (
            " ".join(
                [f"{jypt}{tone}" for jypt, tone in zip(jyutping_output, tone_list)]
            ),
            jyutping_logits,
            tone_logits,
        )
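
    # A hedged usage sketch: with `processor`, `tone_tokenizer`, and
    # `input_features` prepared as in SpeechToJyutpingPipeline.preprocess,
    # `inference` returns the combined "syllable+tone" string plus the raw
    # logits:
    #
    #   text, jyutping_logits, tone_logits = model.inference(
    #       processor, tone_tokenizer, input_features
    #   )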


class EndpointHandler:
    def __init__(self, path="."):
        model_path = "hon9kon9ize/wav2vec2bert-jyutping"
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(model_path)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_path)

        self.pipeline = pipeline(
            task="speech-to-jyutping",
            model=Wav2Vec2BertForCantonese.from_pretrained(model_path),
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): holds a base64-encoded WAV under "audio"
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        inputs = data.pop("inputs", data)

        # Decode the base64 payload to a temporary WAV file so torchaudio
        # can load it in preprocess.
        audio = inputs["audio"]
        audio_bytes = base64.b64decode(audio)
        temp_wav_path = "/tmp/temp.wav"

        with open(temp_wav_path, "wb") as f:
            f.write(audio_bytes)

        prediction = self.pipeline(temp_wav_path)

        return prediction
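

# A minimal local smoke test, assuming a WAV file at "sample.wav" (the path
# is a placeholder, not part of the repo):
if __name__ == "__main__":
    with open("sample.wav", "rb") as f:
        payload = {"inputs": {"audio": base64.b64encode(f.read()).decode("utf-8")}}
    handler = EndpointHandler()
    print(handler(payload))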