import torch

# from transformers import BertTokenizerFast

from colbert.modeling.hf_colbert import class_factory
from colbert.infra import ColBERTConfig
from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length, _insert_prefix_token
from colbert.parameters import DEVICE


class DocTokenizer():
    def __init__(self, config: ColBERTConfig):
        HF_ColBERT = class_factory(config.checkpoint)
        self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint)

        self.config = config
        self.doc_maxlen = config.doc_maxlen

        # config.doc_token is the document marker (e.g. '[D]'); config.doc_token_id
        # holds the underlying vocabulary token it maps to.
        self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id)
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # tokenize() returns plain lists of token strings, so there is nothing to move to a device.
        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Without return_tensors, the tokenizer yields Python lists of ids, so no .to(DEVICE) here.
        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        assert type(batch_text) in [list, tuple], (type(batch_text))

        # Reserve one position (doc_maxlen - 1) for the document marker inserted below.
        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
                       return_tensors='pt', max_length=(self.doc_maxlen - 1)).to(DEVICE)

        # Insert the document marker token id into every sequence and extend the attention mask to cover it.
        ids = _insert_prefix_token(obj['input_ids'], self.D_marker_token_id)
        mask = _insert_prefix_token(obj['attention_mask'], 1)

        if bsize:
            # Sort by length so each batch pads to a similar length, then split into batches;
            # reverse_indices lets callers restore the original input order.
            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
            batches = _split_into_batches(ids, mask, bsize)
            return batches, reverse_indices

        return ids, mask
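

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal example of driving DocTokenizer.tensorize. It assumes a ColBERT
# checkpoint name such as 'colbert-ir/colbertv2.0' is reachable (substitute a
# local checkpoint path if working offline), and that ColBERTConfig accepts
# checkpoint/doc_maxlen as keyword arguments; the passage strings are made up.
if __name__ == '__main__':
    config = ColBERTConfig(checkpoint='colbert-ir/colbertv2.0', doc_maxlen=300)
    doc_tokenizer = DocTokenizer(config)

    passages = [
        'ColBERT encodes each passage into a matrix of token embeddings.',
        'A short second passage.',
    ]

    # Padded/truncated id and mask tensors with the document marker inserted.
    ids, mask = doc_tokenizer.tensorize(passages)
    print(ids.shape, mask.shape)

    # With bsize set, passages are sorted by length and split into batches;
    # reverse_indices maps the sorted order back to the input order.
    batches, reverse_indices = doc_tokenizer.tensorize(passages, bsize=2)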