#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------

import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import Embedding, TransformerEncoder
from fairseq.modules import LayerNorm

from laser_encoders.download_models import LaserModelDownloader
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE
from laser_encoders.laser_tokenizer import LaserTokenizer, initialize_tokenizer

SPACE_NORMALIZER = re.compile(r"\s+")
Batch = namedtuple("Batch", "srcs tokens lengths")

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("embed")


class SentenceEncoder:
    def __init__(
        self,
        model_path,
        max_sentences=None,
        max_tokens=None,
        spm_vocab=None,
        spm_model=None,
        cpu=False,
        fp16=False,
        verbose=False,
        sort_kind="quicksort",
    ):
        if verbose:
            logger.info(f"loading encoder: {model_path}")
        self.spm_model = spm_model
        if self.spm_model:
            self.tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))

        self.use_cuda = torch.cuda.is_available() and not cpu
        self.max_sentences = max_sentences
        self.max_tokens = max_tokens
        if self.max_tokens is None and self.max_sentences is None:
            self.max_sentences = 1

        state_dict = torch.load(model_path)
        if "params" in state_dict:
            self.encoder = LaserLstmEncoder(**state_dict["params"])
            self.encoder.load_state_dict(state_dict["model"])
            self.dictionary = state_dict["dictionary"]
            self.prepend_bos = False
            self.left_padding = False
        else:
            self.encoder = LaserTransformerEncoder(state_dict, spm_vocab)
            self.dictionary = self.encoder.dictionary.indices
            self.prepend_bos = state_dict["cfg"]["model"].prepend_bos
            self.left_padding = state_dict["cfg"]["model"].left_pad_source
        del state_dict

        self.bos_index = self.dictionary["<s>"] = 0
        self.pad_index = self.dictionary["<pad>"] = 1
        self.eos_index = self.dictionary["</s>"] = 2
        self.unk_index = self.dictionary["<unk>"] = 3

        if fp16:
            self.encoder.half()
        if self.use_cuda:
            if verbose:
                logger.info("transfer encoder to GPU")
            self.encoder.cuda()
        self.encoder.eval()
        self.sort_kind = sort_kind

    def __call__(self, text_or_batch):
        if self.spm_model:
            text_or_batch = self.tokenizer(text_or_batch)
            if isinstance(text_or_batch, str):
                text_or_batch = [text_or_batch]
            return self.encode_sentences(text_or_batch)
        else:
            raise ValueError(
                "Either initialize the encoder with an spm_model or pre-tokenize and use the encode_sentences method."
            )

    def _process_batch(self, batch):
        tokens = batch.tokens
        lengths = batch.lengths
        if self.use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        with torch.no_grad():
            sentemb = self.encoder(tokens, lengths)["sentemb"]
        embeddings = sentemb.detach().cpu().numpy()
        return embeddings

    def _tokenize(self, line):
        tokens = SPACE_NORMALIZER.sub(" ", line).strip().split()
        ntokens = len(tokens)
        if self.prepend_bos:
            ids = torch.LongTensor(ntokens + 2)
            ids[0] = self.bos_index
            for i, token in enumerate(tokens):
                ids[i + 1] = self.dictionary.get(token, self.unk_index)
            ids[ntokens + 1] = self.eos_index
        else:
            ids = torch.LongTensor(ntokens + 1)
            for i, token in enumerate(tokens):
                ids[i] = self.dictionary.get(token, self.unk_index)
            ids[ntokens] = self.eos_index
        return ids

    def _make_batches(self, lines):
        tokens = [self._tokenize(line) for line in lines]
        lengths = np.array([t.numel() for t in tokens])
        indices = np.argsort(-lengths, kind=self.sort_kind)

        def batch(tokens, lengths, indices):
            toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]), self.pad_index)
            if not self.left_padding:
                for i in range(len(tokens)):
                    toks[i, : tokens[i].shape[0]] = tokens[i]
            else:
                for i in range(len(tokens)):
                    toks[i, -tokens[i].shape[0] :] = tokens[i]
            return (
                Batch(srcs=None, tokens=toks, lengths=torch.LongTensor(lengths)),
                indices,
            )

        batch_tokens, batch_lengths, batch_indices = [], [], []
        ntokens = nsentences = 0
        for i in indices:
            if nsentences > 0 and (
                (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens)
                or (self.max_sentences is not None and nsentences == self.max_sentences)
            ):
                yield batch(batch_tokens, batch_lengths, batch_indices)
                ntokens = nsentences = 0
                batch_tokens, batch_lengths, batch_indices = [], [], []
            batch_tokens.append(tokens[i])
            batch_lengths.append(lengths[i])
            batch_indices.append(i)
            ntokens += tokens[i].shape[0]
            nsentences += 1
        if nsentences > 0:
            yield batch(batch_tokens, batch_lengths, batch_indices)

    def encode_sentences(self, sentences, normalize_embeddings=False):
        indices = []
        results = []
        for batch, batch_indices in self._make_batches(sentences):
            indices.extend(batch_indices)
            encoded_batch = self._process_batch(batch)
            if normalize_embeddings:
                # Perform L2 normalization on the embeddings
                norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True)
                encoded_batch = encoded_batch / norms
            results.append(encoded_batch)
        return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)]
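
# Example usage of SentenceEncoder (a minimal sketch, not part of the library).
# It assumes a "laser2.pt" checkpoint together with its "laser2.spm" and
# "laser2.cvocab" files has already been downloaded to the working directory:
#
#   encoder = SentenceEncoder(
#       model_path="laser2.pt",
#       spm_vocab="laser2.cvocab",
#       spm_model="laser2.spm",
#   )
#   embeddings = encoder(["Hello world!"])  # numpy array, one row per sentence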
class LaserTransformerEncoder(TransformerEncoder):
    def __init__(self, state_dict, vocab_path):
        self.dictionary = Dictionary.load(vocab_path)
        if any(
            k in state_dict["model"]
            for k in ["encoder.layer_norm.weight", "layer_norm.weight"]
        ):
            self.dictionary.add_symbol("<mask>")
        cfg = state_dict["cfg"]["model"]
        self.sentemb_criterion = cfg.sentemb_criterion
        self.pad_idx = self.dictionary.pad_index
        self.bos_idx = self.dictionary.bos_index
        embed_tokens = Embedding(
            len(self.dictionary),
            cfg.encoder_embed_dim,
            self.pad_idx,
        )
        super().__init__(cfg, self.dictionary, embed_tokens)
        if "decoder.version" in state_dict["model"]:
            self._remove_decoder_layers(state_dict)
        if "layer_norm.weight" in state_dict["model"]:
            self.layer_norm = LayerNorm(cfg.encoder_embed_dim)
        self.load_state_dict(state_dict["model"])

    def _remove_decoder_layers(self, state_dict):
        for key in list(state_dict["model"].keys()):
            if not key.startswith(
                (
                    "encoder.layer_norm",
                    "encoder.layers",
                    "encoder.embed",
                    "encoder.version",
                )
            ):
                del state_dict["model"][key]
            else:
                renamed_key = key.replace("encoder.", "")
                state_dict["model"][renamed_key] = state_dict["model"].pop(key)

    def forward(self, src_tokens, src_lengths):
        encoder_out = super().forward(src_tokens, src_lengths)
        if isinstance(encoder_out, dict):
            x = encoder_out["encoder_out"][0]  # T x B x C
        else:
            x = encoder_out[0]
        if self.sentemb_criterion == "cls":
            cls_indices = src_tokens.eq(self.bos_idx).t()
            sentemb = x[cls_indices, :]
        else:
            padding_mask = src_tokens.eq(self.pad_idx).t().unsqueeze(-1)
            if padding_mask.any():
                x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
            sentemb = x.max(dim=0)[0]
        return {"sentemb": sentemb}
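
# The sentence embedding above is obtained by max-pooling the encoder states
# over time, with padded positions masked out first. A standalone sketch of
# that pooling step (the tensor shapes are illustrative assumptions):
#
#   x = torch.randn(seq_len, bsz, dim)                       # T x B x C encoder states
#   padding_mask = src_tokens.eq(pad_idx).t().unsqueeze(-1)  # True where padded
#   x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
#   sentemb = x.max(dim=0)[0]                                # B x C sentence embeddings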
class LaserLstmEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings,
        padding_idx,
        embed_dim=320,
        hidden_size=512,
        num_layers=1,
        bidirectional=False,
        left_pad=True,
        padding_value=0.0,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        self.padding_idx = padding_idx
        self.embed_tokens = nn.Embedding(
            num_embeddings, embed_dim, padding_idx=self.padding_idx
        )

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
        )
        self.left_pad = left_pad
        self.padding_value = padding_value

        self.output_units = hidden_size
        if bidirectional:
            self.output_units *= 2

    def forward(self, src_tokens, src_lengths):
        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # pack embedded source tokens into a PackedSequence
        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())

        # apply LSTM
        if self.bidirectional:
            state_size = 2 * self.num_layers, bsz, self.hidden_size
        else:
            state_size = self.num_layers, bsz, self.hidden_size
        h0 = x.data.new(*state_size).zero_()
        c0 = x.data.new(*state_size).zero_()
        packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))

        # unpack outputs
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value
        )
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        if self.bidirectional:

            def combine_bidir(outs):
                return torch.cat(
                    [
                        torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(
                            1, bsz, self.output_units
                        )
                        for i in range(self.num_layers)
                    ],
                    dim=0,
                )

            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        # Set padded outputs to -inf so they are not selected by max-pooling
        padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
        if padding_mask.any():
            x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        return {
            "sentemb": sentemb,
            "encoder_out": (x, final_hiddens, final_cells),
            "encoder_padding_mask": encoder_padding_mask
            if encoder_padding_mask.any()
            else None,
        }
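
# Calling the LSTM encoder directly (a minimal sketch with made-up sizes).
# pack_padded_sequence is used without enforce_sorted=False, so batches must
# be sorted by decreasing length, which _make_batches guarantees:
#
#   lstm_encoder = LaserLstmEncoder(num_embeddings=1000, padding_idx=1,
#                                   embed_dim=320, hidden_size=512,
#                                   bidirectional=True)
#   tokens = torch.tensor([[5, 8, 2], [7, 2, 1]])   # B x T, padded with index 1
#   lengths = torch.tensor([3, 2])                  # decreasing lengths
#   out = lstm_encoder(tokens, lengths)
#   out["sentemb"].shape                            # (2, 1024) since bidirectional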
def initialize_encoder(
    lang: str = None,
    model_dir: str = None,
    spm: bool = True,
    laser: str = None,
):
    downloader = LaserModelDownloader(model_dir)
    if laser is not None:
        if laser == "laser3":
            lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
            downloader.download_laser3(lang=lang, spm=spm)
            file_path = f"laser3-{lang}.v1"
        elif laser == "laser2":
            downloader.download_laser2()
            file_path = "laser2"
        else:
            raise ValueError(
                f"Unsupported laser model: {laser}. Choose either laser2 or laser3."
            )
    else:
        if lang in LASER3_LANGUAGE:
            lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
            downloader.download_laser3(lang=lang, spm=spm)
            file_path = f"laser3-{lang}.v1"
        elif lang in LASER2_LANGUAGE:
            downloader.download_laser2()
            file_path = "laser2"
        else:
            raise ValueError(
                f"Unsupported language name: {lang}. Please specify a supported language name."
            )

    model_dir = downloader.model_dir
    model_path = os.path.join(model_dir, f"{file_path}.pt")
    spm_vocab = os.path.join(model_dir, f"{file_path}.cvocab")

    if not os.path.exists(spm_vocab):
        # if there is no cvocab for the laser3 lang use laser2 cvocab
        spm_vocab = os.path.join(model_dir, "laser2.cvocab")

    return SentenceEncoder(model_path=model_path, spm_vocab=spm_vocab, spm_model=None)
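
# Example (a minimal sketch; the language name and directory are illustrative
# assumptions): initialize_encoder downloads the matching checkpoint on first
# use and returns an encoder without a tokenizer attached, so its input must
# already be SPM-tokenized, e.g. with initialize_tokenizer:
#
#   encoder = initialize_encoder(lang="igbo", model_dir="/tmp/laser_models")
#   tokenizer = initialize_tokenizer(lang="igbo", model_dir="/tmp/laser_models")
#   embeddings = encoder.encode_sentences([tokenizer.tokenize("Hello world")])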
class LaserEncoderPipeline:
    def __init__(
        self,
        lang: str = None,
        model_dir: str = None,
        spm: bool = True,
        laser: str = None,
    ):
        if laser == "laser2" and lang is not None:
            warnings.warn(
                "Warning: The 'lang' parameter is optional when using 'laser2'. It will be ignored."
            )

        if laser == "laser3" and lang is None:
            raise ValueError("For 'laser3', the 'lang' parameter is required.")

        if laser is None and lang is None:
            raise ValueError("Either 'laser' or 'lang' should be provided.")

        self.tokenizer = initialize_tokenizer(
            lang=lang, model_dir=model_dir, laser=laser
        )
        self.encoder = initialize_encoder(
            lang=lang, model_dir=model_dir, spm=spm, laser=laser
        )

    def encode_sentences(
        self, sentences: list, normalize_embeddings: bool = False
    ) -> np.ndarray:
        """
        Tokenizes and encodes a list of sentences.

        Args:
        - sentences (list of str): List of sentences to tokenize and encode.
        - normalize_embeddings (bool): If True, L2-normalize each embedding.

        Returns:
        - numpy.ndarray of shape (len(sentences), embedding_dim), with one
          embedding per input sentence in the original order.
        """
        tokenized_sentences = [
            self.tokenizer.tokenize(sentence) for sentence in sentences
        ]
        return self.encoder.encode_sentences(tokenized_sentences, normalize_embeddings)
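
# Example usage of LaserEncoderPipeline (a minimal sketch; "igbo" is just an
# illustrative language name, and the first call downloads the model files):
#
#   pipeline = LaserEncoderPipeline(lang="igbo")
#   embeddings = pipeline.encode_sentences(["Hello world"], normalize_embeddings=True)
#   embeddings.shape  # (1, embedding_dim), one L2-normalized row per sentence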