import warnings
from typing import List, Optional

import langdetect
from transformers import MarianMTModel, MarianTokenizer

from unstructured.nlp.tokenize import sent_tokenize
from unstructured.staging.huggingface import chunk_by_attention_window


def _get_opus_mt_model_name(source_lang: str, target_lang: str) -> str:
    """Constructs the name of the MarianMT machine translation model based on the
    source and target language."""
    return f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"


def _validate_language_code(language_code: str) -> None:
    """Raises a ValueError unless language_code is a string of exactly two characters."""
    if not isinstance(language_code, str) or len(language_code) != 2:
        raise ValueError(
            f"Invalid language code: {language_code}. Language codes must be two letter strings.",
        )


def translate_text(text: str, source_lang: Optional[str] = None, target_lang: str = "en") -> str:
    """Translates the foreign language text. If the source language is not specified, the
    function will attempt to detect it using langdetect.

    Parameters
    ----------
    text: str
        The text to translate
    source_lang: Optional[str]
        The two letter language code for the language of the input text. If source_lang is
        not provided, the function will try to detect it.
    target_lang: str
        The two letter language code for the target language. Defaults to "en".

    Returns
    -------
    str
        The translated text. The input is returned unchanged when it is empty or
        whitespace-only, or when the source and target languages are the same.

    Raises
    ------
    ValueError
        If either language code is not a two letter string, or if no translation model
        can be loaded for the requested source/target language combination.
    """
    # Nothing to translate; this also avoids running language detection on blank input.
    if text.strip() == "":
        return text

    _source_lang: str = source_lang if source_lang is not None else langdetect.detect(text)
    # NOTE(robinson) - Chinese gets detected with codes zh-cn, zh-tw, zh-hk for various
    # Chinese variants. We normalize these because there is a single model for Chinese
    # machine translation
    if _source_lang.startswith("zh"):
        _source_lang = "zh"

    _validate_language_code(target_lang)
    _validate_language_code(_source_lang)

    if target_lang == _source_lang:
        return text

    model_name = _get_opus_mt_model_name(_source_lang, target_lang)

    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
    except OSError as err:
        # Keep the original OSError attached as the cause so the underlying loading
        # failure is still visible in the traceback.
        raise ValueError(
            f"Transformers could not find the translation model {model_name}. "
            "The requested source/target language combo is not supported.",
        ) from err

    chunks: List[str] = chunk_by_attention_window(text, tokenizer, split_function=sent_tokenize)

    # Translate each chunk on its own. Passing the full text here instead of the chunk
    # would translate the whole document once per chunk and duplicate the output.
    translated_chunks: List[str] = [_translate_text(chunk, model, tokenizer) for chunk in chunks]

    return " ".join(translated_chunks)


def _translate_text(text, model, tokenizer):
    """Translates text using the specified model and tokenizer.

    The text is tokenized as a single-item batch padded to 512 tokens, passed to
    model.generate, and the first generated sequence is decoded with special tokens
    skipped.
    """
    # NOTE(robinson) - Suppresses the HuggingFace UserWarning resulting from the "max_length"
    # key in the MarianMT config. The warning states that "max_length" will be deprecated
    # in transformers v5
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        translated = model.generate(
            **tokenizer([text], return_tensors="pt", padding="max_length", max_length=512),
        )
    # Only one text is submitted, so only the first generated sequence is needed.
    # NOTE(review): max_new_tokens looks like a generation argument rather than a decode
    # argument; it is kept as-is to preserve existing behavior - confirm whether it has
    # any effect here.
    return tokenizer.decode(translated[0], max_new_tokens=512, skip_special_tokens=True)
Memory