import base64
import os
import re
import tempfile
from itertools import groupby
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Dict, List, Any
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from transformers.modeling_outputs import ModelOutput
from transformers import (
    Wav2Vec2BertProcessor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2BertModel,
    Wav2Vec2BertPreTrainedModel,
    SeamlessM4TFeatureExtractor,
    pipeline,
    Pipeline,
)
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
    _HIDDEN_STATES_START_POSITION,
)
from transformers.pipelines import PIPELINE_REGISTRY
import torchaudio

# Jyutping onsets (initials). A recognised token from this set marks the start
# of a new syllable when splitting the CTC output into syllables.
ONSETS = {
    "b",
    "d",
    "g",
    "gw",
    "z",
    "p",
    "t",
    "k",
    "kw",
    "c",
    "m",
    "n",
    "ng",
    "f",
    "h",
    "s",
    "l",
    "w",
    "j",
}

# Sampling rate the audio is resampled to before feature extraction.
_TARGET_SAMPLING_RATE = 16_000

# The first ids of the tone vocabulary are special tokens (padding, unknown,
# separator) and never represent a real tone.
_NUM_SPECIAL_TONE_TOKENS = 3


class SpeechToJyutpingPipeline(Pipeline):
    """Transcribe an audio file into tone-marked Jyutping.

    ``__call__`` takes anything ``torchaudio.load`` accepts (e.g. a file path)
    and returns ``{"transcription": str, "tone_probs": List[List[float]]}``,
    where ``tone_probs`` has one softmax distribution over the real tone
    classes per syllable that received a tone.
    """

    def _sanitize_parameters(self, **kwargs):
        """Build helper objects once; the pipeline accepts no call-time kwargs.

        transformers invokes this hook at construction and again on every
        call, so the tone-vocab download and tokenizer/processor construction
        are cached on the instance instead of being repeated per request.
        """
        if getattr(self, "tone_tokenizer", None) is None:
            tone_vocab_file = hf_hub_download(
                repo_id="hon9kon9ize/wav2vec2bert-jyutping", filename="tone_vocab.json"
            )
            self.tone_tokenizer = Wav2Vec2CTCTokenizer(
                tone_vocab_file,
                unk_token="[UNK]",
                pad_token="[PAD]",
                word_delimiter_token="|",
            )
            self.processor = Wav2Vec2BertProcessor(
                feature_extractor=self.feature_extractor,
                tokenizer=self.tokenizer,
            )
            self.onset_ids = {
                self.processor.tokenizer.convert_tokens_to_ids(onset)
                for onset in ONSETS
            }
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        """Load audio, resample it, mix it to mono and extract input features."""
        waveform, original_sampling_rate = torchaudio.load(inputs)
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sampling_rate, new_freq=_TARGET_SAMPLING_RATE
        )
        # torchaudio.load returns (channels, frames). Averaging over channels
        # gives a mono signal; flattening would instead concatenate the
        # channels end-to-end (doubling the length of stereo audio).
        mono = resampler(waveform).mean(dim=0)
        input_features = self.processor(
            mono.numpy(), sampling_rate=_TARGET_SAMPLING_RATE, return_tensors="pt"
        ).input_features
        return {"input_features": input_features.to(self.device)}

    def _forward(self, model_inputs):
        """Run the model and return the Jyutping and tone logits."""
        outputs = self.model(
            input_features=model_inputs["input_features"],
        )
        return {
            "jyutping_logits": outputs.jyutping_logits,
            "tone_logits": outputs.tone_logits,
            # NOTE(review): despite its name this holds the input features,
            # not a duration; postprocess does not read it.
            "duration": model_inputs["input_features"],
        }

    def _syllable_tone(self, tone_logits, start, end):
        """Score the tone of frames ``[start, end)`` of the first batch item.

        Per-frame tone logits are summed over the span. Returns
        ``(tone_id, tone_probs)``: ``tone_probs`` is a softmax over the real
        tone classes only, and ``tone_id`` is the tone-vocabulary id of its
        argmax, so a special token can never be chosen.
        """
        summed = tone_logits[0, start:end, :].float().sum(dim=0)
        tone_probs = summed[_NUM_SPECIAL_TONE_TOKENS:].softmax(dim=-1)
        tone_id = int(torch.argmax(tone_probs).item()) + _NUM_SPECIAL_TONE_TOKENS
        return tone_id, tone_probs

    def postprocess(self, model_outputs):
        """Decode logits into a space-separated, tone-marked Jyutping string.

        The CTC output is split into "words" at the word-delimiter token. Each
        word is either an onset or a final, and a syllable runs from one onset
        to the next. Each syllable's tone is chosen from the tone logits summed
        over its frames.
        """
        tone_logits = model_outputs["tone_logits"]
        predicted_ids = torch.argmax(model_outputs["jyutping_logits"], dim=-1)
        decoded = self.processor.batch_decode(predicted_ids)[0]
        words = [w for w in decoded.split(" ") if len(w) > 0]

        tokenizer = self.processor.tokenizer
        frame_ids = predicted_ids[0].tolist()
        num_frames = len(frame_ids)
        # (frame index, token id) of every frame that is not CTC padding.
        non_blank = [
            (index, token_id)
            for index, token_id in enumerate(frame_ids)
            if token_id != tokenizer.pad_token_id
        ]
        # The first (frame index, token id) of each delimiter-separated word.
        word_starts = [
            next(group)
            for is_delimiter, group in groupby(
                non_blank, lambda item: item[1] == tokenizer.word_delimiter_token_id
            )
            if not is_delimiter
        ]
        # Internal invariant: one id group per decoded word.
        assert len(word_starts) == len(words), "word/id-group count mismatch"

        pieces = []
        tone_probs = []
        syllable_start = -1  # frame index of the most recent onset
        for (frame_index, token_id), word in zip(word_starts, words):
            if token_id in self.onset_ids:
                if syllable_start > -1:
                    # A new onset closes the previous syllable: emit its tone.
                    tone_id, probs = self._syllable_tone(
                        tone_logits, syllable_start, frame_index
                    )
                    tone_probs.append(probs)
                    pieces.append(self.tone_tokenizer.decode([tone_id]) + "_")
                pieces.append("_" + word)
                syllable_start = frame_index
            else:
                # NOTE(review): finals seen before the first onset are emitted
                # without a tone of their own.
                pieces.append(word)

        if words:
            # The final syllable has no following onset; it runs to the end of
            # the frames.
            tone_id, probs = self._syllable_tone(
                tone_logits, max(syllable_start, 0), num_frames
            )
            tone_probs.append(probs)
            pieces.append(self.tone_tokenizer.decode([tone_id]) + "_")

        # "_" marks syllable boundaries; turn them into single spaces.
        transcription = re.sub(
            r"\s+", " ", "".join(pieces).replace("_", " ").strip()
        )
        tone_probs = torch.stack(tone_probs).cpu().tolist() if tone_probs else []
        return {"transcription": transcription, "tone_probs": tone_probs}


PIPELINE_REGISTRY.register_pipeline(
    "speech-to-jyutping",
    pipeline_class=SpeechToJyutpingPipeline,
)


@dataclass
class JuytpingOutput(ModelOutput):
    """
    Output type of Wav2Vec2BertForCantonese
    """

    # Sum of jyutping_loss and tone_loss; set only when both label sets are given.
    loss: Optional[torch.FloatTensor] = None
    # Output of the Jyutping head (one score per vocabulary entry per frame).
    jyutping_logits: torch.FloatTensor = None
    # Output of the tone head (one score per tone-vocabulary entry per frame).
    tone_logits: torch.FloatTensor = None
    # CTC loss of the Jyutping head.
    jyutping_loss: Optional[torch.FloatTensor] = None
    # CTC loss of the tone head.
    tone_loss: Optional[torch.FloatTensor] = None
    # Encoder hidden states, when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Encoder attentions, when requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
    """
    Wav2Vec2BertForCantonese is a Wav2Vec2BertModel with a language model head on top (a linear layer on top of the hidden-states output) that outputs Jyutping and tone logits.
    """

    def __init__(
        self,
        config,
        tone_vocab_size: int = 9,
    ):
        """Build the encoder plus two linear heads (Jyutping and tone).

        Raises:
            ValueError: if ``config.vocab_size`` is None.
        """
        super().__init__(config)

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # With an adapter, the encoder output width is output_hidden_size.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """Encode features and apply both heads, optionally computing CTC losses.

        Losses are computed only when both ``jyutping_labels`` and
        ``tone_labels`` are given. Negative label entries (e.g. -100) are
        treated as padding.

        Raises:
            ValueError: if a label id is out of range for its vocabulary.
        """
        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(
                    input_features.shape[:2],
                    device=input_features.device,
                    dtype=torch.long,
                )
            )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            jyutping_labels_mask = jyutping_labels >= 0
            jyutping_target_lengths = jyutping_labels_mask.sum(-1)
            jyutping_flattened_targets = jyutping_labels.masked_select(
                jyutping_labels_mask
            )

            # ctc_loss doesn't support fp16
            jyutping_log_probs = nn.functional.log_softmax(
                jyutping_logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                jyutping_loss = nn.functional.ctc_loss(
                    jyutping_log_probs,
                    jyutping_flattened_targets,
                    input_lengths,
                    jyutping_target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

            tone_labels_mask = tone_labels >= 0
            tone_target_lengths = tone_labels_mask.sum(-1)
            tone_flattened_targets = tone_labels.masked_select(tone_labels_mask)

            tone_log_probs = nn.functional.log_softmax(
                tone_logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                # NOTE(review): the tone CTC blank reuses the Jyutping
                # pad_token_id; assumes both vocabularies put padding at the
                # same id — confirm.
                tone_loss = nn.functional.ctc_loss(
                    tone_log_probs,
                    tone_flattened_targets,
                    input_lengths,
                    tone_target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def inference(
        self,
        processor: Wav2Vec2BertProcessor,
        tone_tokenizer: Wav2Vec2CTCTokenizer,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Greedy-decode Jyutping and tones independently, then pair them up.

        Only the first batch item is decoded. The tone sequence is truncated,
        or padded by repeating its last tone, to match the syllable count.

        Returns:
            ``(transcription, jyutping_logits, tone_logits)``.
        """
        outputs = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        jyutping_logits = outputs.jyutping_logits
        tone_logits = outputs.tone_logits
        jyutping_pred_ids = torch.argmax(jyutping_logits, dim=-1)
        tone_pred_ids = torch.argmax(tone_logits, dim=-1)
        jyutping_pred = processor.batch_decode(jyutping_pred_ids)[0]
        tone_pred = tone_tokenizer.batch_decode(tone_pred_ids)[0]
        jyutping_list = jyutping_pred.split(" ")
        tone_list = tone_pred.split(" ")
        # Onsets open a syllable ("_" before), finals close it ("_" after);
        # every "_" later becomes a syllable boundary.
        jyutping_output = []
        for jypt in jyutping_list:
            is_initial = jypt in ONSETS
            if is_initial:
                jypt = "_" + jypt
            else:
                jypt = jypt + "_"
            jyutping_output.append(jypt)

        jyutping_output = re.sub(
            r"\s+", " ", "".join(jyutping_output).replace("_", " ").strip()
        ).split(" ")

        if len(tone_list) > len(jyutping_output):
            tone_list = tone_list[: len(jyutping_output)]
        elif len(tone_list) < len(jyutping_output):
            # repeat the last tone if the length of tone list is shorter than the length of jyutping list
            tone_list = tone_list + [tone_list[-1]] * (
                len(jyutping_output) - len(tone_list)
            )

        return (
            " ".join(
                [f"{jypt}{tone}" for jypt, tone in zip(jyutping_output, tone_list)]
            ),
            jyutping_logits,
            tone_logits,
        )


class EndpointHandler:
    """Inference-endpoint entry point wrapping the speech-to-jyutping pipeline."""

    def __init__(self, path="."):
        # NOTE(review): `path` is ignored; the model is always loaded from the
        # Hub repo below.
        model_path = "hon9kon9ize/wav2vec2bert-jyutping"
        feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(model_path)
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_path)

        self.pipeline = pipeline(
            task="speech-to-jyutping",
            model=Wav2Vec2BertForCantonese.from_pretrained(model_path),
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj: `dict`): must contain "audio", a base64-encoded audio file
        Return:
            A :obj:`dict` with "transcription" and "tone_probs": will be serialized and returned
        """
        # get inputs, assuming a base64 encoded wav file
        inputs = data.pop("inputs", data)
        audio_bytes = base64.b64decode(inputs["audio"])

        # Write to a unique temp file: a fixed shared path could be overwritten
        # by an overlapping request, and was never cleaned up.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_bytes)
            temp_wav_path = f.name

        try:
            # run normal prediction
            return self.pipeline(temp_wav_path)
        finally:
            os.remove(temp_wav_path)