#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# Helper functions for tokenization

import gzip
import logging
import os
import re
import sys
from pathlib import Path
from typing import IO, List

import sentencepiece as spm
from sacremoses import MosesDetokenizer, MosesPunctNormalizer
from unicategories import categories

from laser_encoders.download_models import LaserModelDownloader
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE

SPACE_NORMALIZER = re.compile(r"\s+")
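# Unicode category "C" ("Other") covers control, format, surrogate, private-use
# and unassigned code points; tokenize() replaces them with spaces.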
NON_PRINT_CHARS = set(c for c in categories["C"].characters())

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("preprocess")


class LaserTokenizer:
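    """Tokenizer reproducing LASER's text preprocessing: non-printable
    characters are replaced by spaces, then (optionally) Moses punctuation
    normalization, XML de-escaping and lower-casing are applied before the
    text is encoded with the given SentencePiece model."""
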
    def __init__(
        self,
        spm_model: Path,
        lang: str = "en",
        lower_case: bool = True,
        descape: bool = False,
        verbose: bool = False,
        over_write: bool = False,
        normalize_punct: bool = True,
    ):
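        """Load the SentencePiece model at `spm_model` and store the
        preprocessing flags (lower-casing, XML de-escaping, punctuation
        normalization for `lang`, verbosity and overwrite behaviour)."""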
        self.spm_model = spm_model
        self.lang = lang
        self.lower_case = lower_case
        self.descape = descape
        self.verbose = verbose
        self.over_write = over_write
        self.normalize_punct = normalize_punct

        assert spm_model.exists(), f"spm model file: {spm_model} does not exist"
        self.moses_punct_normalizer = MosesPunctNormalizer(self.lang, perl_parity=True)
        # add parity with MOSES release-4.0
        self.moses_punct_normalizer.substitutions[21] = ("‘", r'"')
        self.moses_punct_normalizer.substitutions[22] = ("‚", r'"')
        self.moses_detokenizer = MosesDetokenizer()
        self.spm_encoder = spm.SentencePieceProcessor(model_file=str(self.spm_model))

    def open(self, file: Path, mode: str, encoding="utf-8") -> IO:
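        """Open `file` for text I/O, using gzip transparently when the file
        name ends in ".gz"."""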
        return (
            gzip.open(file, mode, encoding=encoding)
            if file.name.endswith(".gz")
            else open(file, mode, encoding=encoding)
        )

    def log(self, message: str) -> None:
        if self.verbose:
            logger.info(message)

    def tokenize(self, text: str) -> str:
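        """Preprocess `text` and return it as a space-joined string of
        SentencePiece pieces."""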
        # Preprocessing
        sentence_text = "".join([c if c not in NON_PRINT_CHARS else " " for c in text])
        if self.normalize_punct:
            sentence_text = self.moses_punct_normalizer.normalize(sentence_text)
        if self.descape:
            sentence_text = self.moses_detokenizer.unescape_xml(text=sentence_text)
        if self.lower_case:
            sentence_text = sentence_text.lower()

        # SentencePiece encoding
        encoded_text = " ".join(self.spm_encoder.encode(sentence_text, out_type=str))
        return encoded_text

    def tokenize_file(self, inp_fname: Path, out_fname: Path) -> None:
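        """Tokenize `inp_fname` line by line into `out_fname`; an existing
        output file is left untouched unless `over_write` is True."""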
        if not self.over_write and out_fname.exists():
            self.log(f"tokenized file {out_fname.name} already exists")
            return
        else:
            self.log(
                f"tokenizing {inp_fname.name}"
                + f"{' (de-escaped)' if self.descape else ''}"
                + f"{' (lower-cased)' if self.lower_case else ' (cased)'} "
                + f"(punctuation-normalization lang: {self.lang})"
            )

            with self.open(inp_fname, "rt") as file_in, open(
                out_fname, "w"
            ) as file_out:
                for line in file_in:
                    tokens = self.tokenize(line.strip())
                    file_out.write(tokens + "\n")

    def __call__(self, text_or_batch):
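        """Tokenize a single string or a list of strings."""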
        if isinstance(text_or_batch, str):
            return self.tokenize(text_or_batch)
        else:
            return self.tokenize_batch(text_or_batch)

    def tokenize_batch(self, batch: List[str]) -> List[str]:
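        """Tokenize every string in `batch`, returning one token string per input."""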
        return [self.tokenize(text) for text in batch]

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
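        """Decode each SentencePiece id back to its surface text."""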
        return [self.spm_encoder.DecodeIds(token_id) for token_id in ids]

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
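        """Map SentencePiece pieces to their vocabulary ids, falling back to
        the <unk> id for out-of-vocabulary pieces."""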
        ids = []

        for token in tokens:
            # Collapse whitespace and split the token into individual pieces
            pieces = SPACE_NORMALIZER.sub(" ", token).strip().split()

            # Collect the vocabulary ids for this token's pieces
            token_ids = []
            for piece in pieces:
                piece_id = self.spm_encoder.PieceToId(piece)
                if piece_id == 0:  # Handle out-of-vocabulary pieces
                    piece_id = self.spm_encoder.PieceToId("<unk>")
                token_ids.append(piece_id)

            # Append this token's ids to the final list
            ids.extend(token_ids)

        return ids


def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = None):
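    """Download (if necessary) the SentencePiece model matching `lang` and/or
    `laser` ("laser2" or "laser3") and return a ready-to-use LaserTokenizer."""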
    downloader = LaserModelDownloader(model_dir)
    if laser is not None:
        if laser == "laser3":
            lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
            if lang in SPM_LANGUAGE:
                filename = f"laser3-{lang}.v1.spm"
            else:
                filename = "laser2.spm"
        elif laser == "laser2":
            filename = "laser2.spm"
        else:
            raise ValueError(
                f"Unsupported laser model: {laser}. Choose either laser2 or laser3."
            )
    else:
        if lang in LASER3_LANGUAGE:
            lang = downloader.get_language_code(LASER3_LANGUAGE, lang)
            if lang in SPM_LANGUAGE:
                filename = f"laser3-{lang}.v1.spm"
            else:
                filename = "laser2.spm"
        elif lang in LASER2_LANGUAGE:
            filename = "laser2.spm"
        else:
            raise ValueError(
                f"Unsupported language name: {lang}. Please specify a supported language name."
            )

    downloader.download(filename)
    model_path = os.path.join(downloader.model_dir, filename)
    return LaserTokenizer(spm_model=Path(model_path))
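

# Illustrative usage sketch, not part of the original module. It assumes
# network access for the one-off model download and that "igbo" is among the
# language names accepted by laser_encoders.language_list (LASER3_LANGUAGE).
if __name__ == "__main__":
    tokenizer = initialize_tokenizer(lang="igbo")
    # Single sentence -> space-joined SentencePiece pieces
    print(tokenizer("Kedu ka i mere?"))
    # Batch of sentences -> list of token strings
    print(tokenizer(["First sentence.", "Second sentence."]))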