from __future__ import annotations import os from functools import lru_cache from typing import Final, List, Tuple import nltk from nltk import pos_tag as _pos_tag from nltk import sent_tokenize as _sent_tokenize from nltk import word_tokenize as _word_tokenize CACHE_MAX_SIZE: Final[int] = 128 def check_for_nltk_package(package_name: str, package_category: str) -> bool: """Checks to see if the specified NLTK package exists on the image.""" paths: list[str] = [] for path in nltk.data.path: if not path.endswith("nltk_data"): path = os.path.join(path, "nltk_data") paths.append(path) try: nltk.find(f"{package_category}/{package_name}", paths=paths) return True except (LookupError, OSError): return False def download_nltk_packages(): """If required NLTK packages are not available, download them.""" tagger_available = check_for_nltk_package( package_category="taggers", package_name="averaged_perceptron_tagger_eng", ) tokenizer_available = check_for_nltk_package( package_category="tokenizers", package_name="punkt_tab" ) if (not tokenizer_available) or (not tagger_available): nltk.download("averaged_perceptron_tagger_eng", quiet=True) nltk.download("punkt_tab", quiet=True) # auto download nltk packages if the environment variable is set if os.getenv("AUTO_DOWNLOAD_NLTK", "True").lower() == "true": download_nltk_packages() @lru_cache(maxsize=CACHE_MAX_SIZE) def sent_tokenize(text: str) -> List[str]: """A wrapper around the NLTK sentence tokenizer with LRU caching enabled.""" return _sent_tokenize(text) @lru_cache(maxsize=CACHE_MAX_SIZE) def word_tokenize(text: str) -> List[str]: """A wrapper around the NLTK word tokenizer with LRU caching enabled.""" return _word_tokenize(text) @lru_cache(maxsize=CACHE_MAX_SIZE) def pos_tag(text: str) -> List[Tuple[str, str]]: """A wrapper around the NLTK POS tagger with LRU caching enabled.""" # Splitting into sentences before tokenizing. sentences = _sent_tokenize(text) parts_of_speech: list[tuple[str, str]] = [] for sentence in sentences: tokens = _word_tokenize(sentence) parts_of_speech.extend(_pos_tag(tokens)) return parts_of_speech
Memory