import os
import tqdm
import time
import ujson
import torch
import random

try:
    import faiss
except ImportError:
    print("WARNING: faiss must be imported for indexing")

import numpy as np
import torch.multiprocessing as mp
from colbert.infra.config.config import ColBERTConfig

import colbert.utils.distributed as distributed

from colbert.infra.run import Run
from colbert.infra.launcher import print_memory_stats
from colbert.modeling.checkpoint import Checkpoint
from colbert.data.collection import Collection

from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.index_saver import IndexSaver
from colbert.indexing.utils import optimize_ivf
from colbert.utils.utils import flatten, print_message

from colbert.indexing.codecs.residual import ResidualCodec


def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    '''
    Build (this rank's share of) the index for `collection` under `config`.

    `shared_lists` is forwarded to CollectionIndexer.run (used to hand the k-means
    sample to a forked faiss process). `shared_queues` is accepted for interface
    compatibility but is not used here.
    '''
    encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)


class CollectionIndexer():
    '''
    Given a collection and config, encode the collection into an index and
    store the index on disk in chunks.
    '''

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        self.verbose = verbose
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks

        self.use_gpu = self.config.total_visible_gpus > 0

        if self.config.rank == 0 and self.verbose > 1:
            self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)
        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)

        print_memory_stats(f'RANK:{self.rank}')

    def run(self, shared_lists):
        '''
        Run the full indexing pipeline: plan -> train codec -> encode chunks -> finalize.
        Every stage is followed by a barrier so all ranks move in lockstep.
        '''
        with torch.inference_mode():
            self.setup()  # Computes and saves plan for whole collection
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            # On resume, reuse a previously saved codec instead of retraining k-means.
            if not self.config.resume or not self.saver.try_load_codec():
                self.train(shared_lists)  # Trains centroids from selected passages
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            self.index()  # Encodes and saves all tokens into residuals
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            self.finalize()  # Builds metadata and centroid to passage mapping
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

    def setup(self):
        '''
        Calculates and saves plan.json for the whole collection.

        plan.json { config, num_chunks, num_partitions, num_embeddings_est, avg_doclen_est}
        num_partitions is the number of centroids to be generated.
        '''
        if self.config.resume:
            if self._try_load_plan():
                if self.verbose > 1:
                    Run().print_main(f"#> Loaded plan from {self.plan_path}:")
                    Run().print_main(f"#> num_chunks = {self.num_chunks}")
                    Run().print_main(f"#> num_partitions = {self.num_partitions}")
                    Run().print_main(f"#> num_embeddings_est = {self.num_embeddings_est}")
                    Run().print_main(f"#> avg_doclen_est = {self.avg_doclen_est}")
                return

        self.num_chunks = int(np.ceil(len(self.collection) / self.collection.get_chunksize()))

        # Saves sampled passages and embeddings for training k-means centroids later
        sampled_pids = self._sample_pids()
        avg_doclen_est = self._sample_embeddings(sampled_pids)

        # Select the number of partitions: the largest power of two <= 16 * sqrt(#embeddings).
        num_passages = len(self.collection)
        self.num_embeddings_est = num_passages * avg_doclen_est
        self.num_partitions = int(2 ** np.floor(np.log2(16 * np.sqrt(self.num_embeddings_est))))

        if self.verbose > 0:
            Run().print_main(f'Creating {self.num_partitions:,} partitions.')
            Run().print_main(f'*Estimated* {int(self.num_embeddings_est):,} embeddings.')

        self._save_plan()

    def _sample_pids(self):
        '''
        Randomly choose the passage ids whose embeddings will be used to train k-means.
        Returns a set of pids, about 16 * sqrt(120 * num_passages) of them, capped at num_passages.
        '''
        num_passages = len(self.collection)

        # Simple alternative: < 100k: 100%, < 1M: 15%, < 10M: 7%, < 100M: 3%, > 100M: 1%
        # Keep in mind that, say, 15% still means at least 100k.
        # So the formula is max(100% * min(total, 100k), 15% * min(total, 1M), ...)
        # Then we subsample the vectors to 100 * num_partitions

        typical_doclen = 120  # let's keep sampling independent of the actual doc_maxlen
        sampled_pids = 16 * np.sqrt(typical_doclen * num_passages)
        # sampled_pids = int(2 ** np.floor(np.log2(1 + sampled_pids)))
        sampled_pids = min(1 + int(sampled_pids), num_passages)

        sampled_pids = random.sample(range(num_passages), sampled_pids)
        if self.verbose > 1:
            Run().print_main(f"# of sampled PIDs = {len(sampled_pids)} \t sampled_pids[:3] = {sampled_pids[:3]}")

        return set(sampled_pids)

    def _sample_embeddings(self, sampled_pids):
        '''
        Encode this rank's sampled passages, save them to sample.{rank}.pt, and
        estimate the average document length across all ranks.

        Sets self.num_sample_embs (total sample rows across ranks, as a 1-element
        tensor) and self.avg_doclen_est; returns avg_doclen_est.
        '''
        local_pids = self.collection.enumerate(rank=self.rank)
        local_sample = [passage for pid, passage in local_pids if pid in sampled_pids]

        local_sample_embs, doclens = self.encoder.encode_passages(local_sample)

        # NOTE(review): the encoder may return None for an empty batch of passages.
        # Substitute an empty (0, dim) sample so this rank still joins every
        # collective below and writes a (possibly empty) sample file.
        if local_sample_embs is None:
            local_sample_embs = torch.empty(0, self.config.dim)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

        self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).to(device)

        # Always float32: a rank with no doclens must not produce an int64 tensor,
        # because all_reduce requires the same dtype on every rank.
        local_avg_doclen = sum(doclens) / len(doclens) if doclens else 0.0
        avg_doclen_est = torch.tensor([local_avg_doclen], dtype=torch.float32).to(device)

        # Counts ranks that contributed a non-empty sample, to average the per-rank means.
        nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).to(device)

        if is_distributed:
            # Same collective order on every rank.
            torch.distributed.all_reduce(self.num_sample_embs)
            torch.distributed.all_reduce(avg_doclen_est)
            torch.distributed.all_reduce(nonzero_ranks)

        avg_doclen_est = avg_doclen_est.item() / nonzero_ranks.item()
        self.avg_doclen_est = avg_doclen_est

        Run().print(f'avg_doclen_est = {avg_doclen_est} \t len(local_sample) = {len(local_sample):,}')

        torch.save(local_sample_embs.half(), os.path.join(self.config.index_path_, f'sample.{self.rank}.pt'))

        return avg_doclen_est

    def _try_load_plan(self):
        '''
        Load num_chunks / num_partitions / num_embeddings_est / avg_doclen_est from
        an existing plan.json. Returns True on success, False if the file is
        missing, not valid JSON, or lacks any of those keys.
        '''
        config = self.config
        self.plan_path = os.path.join(config.index_path_, 'plan.json')
        if not os.path.exists(self.plan_path):
            return False

        with open(self.plan_path, 'r') as f:
            try:
                plan = ujson.load(f)
            except ValueError:
                # Corrupt or truncated plan file: fall back to recomputing the plan.
                return False

        required_keys = ('num_chunks', 'num_partitions', 'num_embeddings_est', 'avg_doclen_est')
        if not all(key in plan for key in required_keys):
            return False

        # TODO: Verify config matches
        self.num_chunks = plan['num_chunks']
        self.num_partitions = plan['num_partitions']
        self.num_embeddings_est = plan['num_embeddings_est']
        self.avg_doclen_est = plan['avg_doclen_est']

        return True

    def _save_plan(self):
        '''Write plan.json (config + planning estimates); only rank 0 writes.'''
        if self.rank < 1:
            config = self.config
            self.plan_path = os.path.join(config.index_path_, 'plan.json')
            Run().print("#> Saving the indexing plan to", self.plan_path, "..")

            with open(self.plan_path, 'w') as f:
                d = {'config': config.export()}
                d['num_chunks'] = self.num_chunks
                d['num_partitions'] = self.num_partitions
                d['num_embeddings_est'] = self.num_embeddings_est
                d['avg_doclen_est'] = self.avg_doclen_est

                f.write(ujson.dumps(d, indent=4) + '\n')

    def train(self, shared_lists):
        '''
        Train k-means centroids on the gathered sample and save the residual codec.
        Only rank 0 does any work.
        '''
        if self.rank > 0:
            return

        sample, heldout = self._concatenate_and_split_sample()

        centroids = self._train_kmeans(sample, shared_lists)

        print_memory_stats(f'RANK:{self.rank}')
        del sample

        bucket_cutoffs, bucket_weights, avg_residual = self._compute_avg_residual(centroids, heldout)

        if self.verbose > 1:
            print_message(f'avg_residual = {avg_residual}')

        # Compute and save codec into avg_residual.pt, buckets.pt and centroids.pt
        codec = ResidualCodec(config=self.config, centroids=centroids, avg_residual=avg_residual,
                              bucket_cutoffs=bucket_cutoffs, bucket_weights=bucket_weights)
        self.saver.save_codec(codec)

    def _concatenate_and_split_sample(self):
        '''
        Load every rank's sample.{r}.pt into one float16 tensor (deleting the files),
        shuffle it, and split off a heldout set of 5% of the rows (at most 50k).
        Returns (sample, heldout).
        '''
        print_memory_stats(f'***1*** \t RANK:{self.rank}')

        sample_paths = [os.path.join(self.config.index_path_, f'sample.{r}.pt') for r in range(self.nranks)]

        num_sample_embs = getattr(self, 'num_sample_embs', None)
        if num_sample_embs is None:
            # On resume, setup() loads the plan and skips sampling, so the total row
            # count was never computed in this process; recount it from the files.
            num_sample_embs = sum(torch.load(path).size(0) for path in sample_paths)
        num_sample_embs = int(num_sample_embs)  # accepts a 1-element tensor or an int

        # Preallocate the float16 array and copy each rank's sample in, to avoid
        # holding both the pieces and a concatenated copy in memory.
        sample = torch.empty(num_sample_embs, self.config.dim, dtype=torch.float16)

        offset = 0
        for sub_sample_path in sample_paths:
            sub_sample = torch.load(sub_sample_path)
            os.remove(sub_sample_path)

            endpos = offset + sub_sample.size(0)
            sample[offset:endpos] = sub_sample
            offset = endpos

        assert offset == sample.size(0), (offset, sample.size())

        print_memory_stats(f'***2*** \t RANK:{self.rank}')

        # Shuffle and split out a 5% "heldout" sub-sample [up to 50k elements]
        sample = sample[torch.randperm(sample.size(0))]

        print_memory_stats(f'***3*** \t RANK:{self.rank}')

        heldout_fraction = 0.05
        heldout_size = int(min(heldout_fraction * sample.size(0), 50_000))
        sample, sample_heldout = sample.split([sample.size(0) - heldout_size, heldout_size], dim=0)

        print_memory_stats(f'***4*** \t RANK:{self.rank}')

        return sample, sample_heldout

    def _train_kmeans(self, sample, shared_lists):
        '''
        Run faiss k-means on `sample` and return L2-normalized centroids
        (float16 when using GPU, float32 otherwise).
        '''
        if self.use_gpu:
            torch.cuda.empty_cache()

        do_fork_for_faiss = False  # set to True to free faiss GPU-0 memory at the cost of one more copy of `sample`.

        args_ = [self.config.dim, self.num_partitions, self.config.kmeans_niters]

        if do_fork_for_faiss:
            # For this to work reliably, write the sample to disk. Pickle may not handle >4GB of data.
            # Delete the sample file after work is done.

            shared_lists[0][0] = sample
            return_value_queue = mp.Queue()

            args_ = args_ + [shared_lists, return_value_queue]
            proc = mp.Process(target=compute_faiss_kmeans, args=args_)

            proc.start()
            centroids = return_value_queue.get()
            proc.join()

        else:
            args_ = args_ + [[[sample]]]
            centroids = compute_faiss_kmeans(*args_)

        centroids = torch.nn.functional.normalize(centroids, dim=-1)
        if self.use_gpu:
            centroids = centroids.half()
        else:
            centroids = centroids.float()

        return centroids

    def _compute_avg_residual(self, centroids, heldout):
        '''
        Measure residuals of the heldout sample against its nearest centroids.
        Returns (bucket_cutoffs, bucket_weights, mean absolute residual), where the
        buckets are evenly spaced quantiles of the residual values (2**nbits options).
        '''
        compressor = ResidualCodec(config=self.config, centroids=centroids, avg_residual=None)

        heldout_reconstruct = compressor.compress_into_codes(heldout, out_device='cuda' if self.use_gpu else 'cpu')
        heldout_reconstruct = compressor.lookup_centroids(heldout_reconstruct, out_device='cuda' if self.use_gpu else 'cpu')
        if self.use_gpu:
            heldout_avg_residual = heldout.cuda() - heldout_reconstruct
        else:
            heldout_avg_residual = heldout - heldout_reconstruct

        avg_residual = torch.abs(heldout_avg_residual).mean(dim=0).cpu()
        print([round(x, 3) for x in avg_residual.squeeze().tolist()])

        num_options = 2 ** self.config.nbits
        quantiles = torch.arange(0, num_options, device=heldout_avg_residual.device) * (1 / num_options)
        # Cutoffs sit at bucket boundaries (dropping 0); weights sit at bucket midpoints.
        bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)

        bucket_cutoffs = heldout_avg_residual.float().quantile(bucket_cutoffs_quantiles)
        bucket_weights = heldout_avg_residual.float().quantile(bucket_weights_quantiles)

        if self.verbose > 2:
            print_message(
                f"#> Got bucket_cutoffs_quantiles = {bucket_cutoffs_quantiles} and bucket_weights_quantiles = {bucket_weights_quantiles}")
            print_message(f"#> Got bucket_cutoffs = {bucket_cutoffs} and bucket_weights = {bucket_weights}")

        return bucket_cutoffs, bucket_weights, avg_residual.mean()

        # EVENTUALLY: Compare the above with non-heldout sample. If too different, we can do better!
        # sample = sample[subsample_idxs]
        # sample_reconstruct = get_centroids_for(centroids, sample)
        # sample_avg_residual = (sample - sample_reconstruct).mean(dim=0)

    def index(self):
        '''
        Encode embeddings for all passages in collection.
        Each embedding is converted to code (centroid id) and residual.
        Embeddings stored according to passage order in contiguous chunks of memory.

        Saved data files described below:
            {CHUNK#}.codes.pt:      centroid id for each embedding in chunk
            {CHUNK#}.residuals.pt:  16-bits residual for each embedding in chunk
            doclens.{CHUNK#}.pt:    number of embeddings within each passage in chunk
        '''
        with self.saver.thread():
            batches = self.collection.enumerate_batches(rank=self.rank)
            for chunk_idx, offset, passages in tqdm.tqdm(batches, disable=self.rank > 0):
                if self.config.resume and self.saver.check_chunk_exists(chunk_idx):
                    if self.verbose > 2:
                        Run().print_main(f"#> Found chunk {chunk_idx} in the index already, skipping encoding...")
                    continue

                # Encode passages into embeddings with the checkpoint model
                embs, doclens = self.encoder.encode_passages(passages)
                if self.use_gpu:
                    assert embs.dtype == torch.float16
                else:
                    assert embs.dtype == torch.float32
                    embs = embs.half()

                if self.verbose > 1:
                    Run().print_main(f"#> Saving chunk {chunk_idx}: \t {len(passages):,} passages "
                                     f"and {embs.size(0):,} embeddings. From #{offset:,} onward.")

                self.saver.save_chunk(chunk_idx, offset, embs, doclens)  # offset = first passage index in chunk
                del embs, doclens

    def finalize(self):
        '''
        Aggregates and stores metadata for each chunk and the whole index
        Builds and saves inverse mapping from centroids to passage IDs

        Saved data files described below:
            {CHUNK#}.metadata.json: [ passage_offset, num_passages, num_embeddings, embedding_offset ]
            metadata.json: [ num_chunks, num_partitions, num_embeddings, avg_doclen ]
            inv.pid.pt: [ ivf, ivf_lengths ]
                ivf is an array of passage IDs for centroids 0, 1, ...
                ivf_length contains the number of passage IDs for each centroid
        '''
        if self.rank > 0:
            return

        self._check_all_files_are_saved()
        self._collect_embedding_id_offset()

        self._build_ivf()
        self._update_metadata()

    def _check_all_files_are_saved(self):
        '''
        Log an error for every chunk the saver cannot find.
        Returns True if all chunks exist (previously returned None; callers ignore it).
        '''
        if self.verbose > 1:
            Run().print_main("#> Checking all files were saved...")
        success = True
        for chunk_idx in range(self.num_chunks):
            if not self.saver.check_chunk_exists(chunk_idx):
                success = False
                Run().print_main(f"#> ERROR: Could not find chunk {chunk_idx}!")
                # TODO: Fail here?
        if success:
            if self.verbose > 1:
                Run().print_main("Found all files!")
        return success

    def _collect_embedding_id_offset(self):
        '''
        Walk the chunk metadata files in order, record each chunk's global embedding
        offset (written back into its metadata file and kept in self.embedding_offsets),
        and set self.num_embeddings to the total.
        '''
        passage_offset = 0
        embedding_offset = 0

        self.embedding_offsets = []

        for chunk_idx in range(self.num_chunks):
            metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')

            with open(metadata_path) as f:
                chunk_metadata = ujson.load(f)

                chunk_metadata['embedding_offset'] = embedding_offset
                self.embedding_offsets.append(embedding_offset)

                # Chunks must tile the passage range contiguously and in order.
                assert chunk_metadata['passage_offset'] == passage_offset, (chunk_idx, passage_offset, chunk_metadata)

                passage_offset += chunk_metadata['num_passages']
                embedding_offset += chunk_metadata['num_embeddings']

            with open(metadata_path, 'w') as f:
                f.write(ujson.dumps(chunk_metadata, indent=4) + '\n')

        self.num_embeddings = embedding_offset
        assert len(self.embedding_offsets) == self.num_chunks

    def _build_ivf(self):
        '''
        Load every chunk's centroid codes, sort embedding ids by code, count
        embeddings per centroid, and hand the result to optimize_ivf to turn the
        centroid->embedding IVF into a centroid->passage IVF on disk.
        '''
        # Maybe we should several small IVFs? Every 250M embeddings, so that's every 1 GB.
        # It would save *memory* here and *disk space* regarding the int64.
        # But we'd have to decide how many IVFs to use during retrieval: many (loop) or one?
        # A loop seems nice if we can find a size that's large enough for speed yet small enough to fit on GPU!
        # Then it would help nicely for batching later: 1GB.

        if self.verbose > 1:
            Run().print_main("#> Building IVF...")

        codes = torch.zeros(self.num_embeddings,).long()
        if self.verbose > 1:
            print_memory_stats(f'RANK:{self.rank}')

        if self.verbose > 1:
            Run().print_main("#> Loading codes...")

        loaded_end = 0  # one past the last embedding id filled so far
        for chunk_idx in tqdm.tqdm(range(self.num_chunks)):
            offset = self.embedding_offsets[chunk_idx]
            chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx)

            loaded_end = offset + chunk_codes.size(0)
            codes[offset:loaded_end] = chunk_codes

        # The last chunk must end exactly at the total embedding count.
        assert loaded_end == codes.size(0), (loaded_end, codes.size())

        if self.verbose > 1:
            Run().print_main(f"Sorting codes...")

            print_memory_stats(f'RANK:{self.rank}')

        codes = codes.sort()
        ivf, values = codes.indices, codes.values

        if self.verbose > 1:
            print_memory_stats(f'RANK:{self.rank}')

            Run().print_main(f"Getting unique codes...")

        ivf_lengths = torch.bincount(values, minlength=self.num_partitions)
        assert ivf_lengths.size(0) == self.num_partitions

        if self.verbose > 1:
            print_memory_stats(f'RANK:{self.rank}')

        # Transforms centroid->embedding ivf to centroid->passage ivf
        _, _ = optimize_ivf(ivf, ivf_lengths, self.config.index_path_)

    def _update_metadata(self):
        '''Write metadata.json with the final chunk/partition/embedding counts and average doclen.'''
        config = self.config
        self.metadata_path = os.path.join(config.index_path_, 'metadata.json')
        if self.verbose > 1:
            Run().print("#> Saving the indexing metadata to", self.metadata_path, "..")

        with open(self.metadata_path, 'w') as f:
            d = {'config': config.export()}
            d['num_chunks'] = self.num_chunks
            d['num_partitions'] = self.num_partitions
            d['num_embeddings'] = self.num_embeddings
            d['avg_doclen'] = self.num_embeddings / len(self.collection)

            f.write(ujson.dumps(d, indent=4) + '\n')


def compute_faiss_kmeans(dim, num_partitions, kmeans_niters, shared_lists, return_value_queue=None):
    '''
    Train faiss k-means (GPU if CUDA is available, fixed seed 123) on the tensor
    stored at shared_lists[0][0] and return the centroids as a torch tensor.
    If `return_value_queue` is given, the centroids are also put on it, for use
    from a forked process.
    '''
    use_gpu = torch.cuda.is_available()
    kmeans = faiss.Kmeans(dim, num_partitions, niter=kmeans_niters, gpu=use_gpu, verbose=True, seed=123)

    sample = shared_lists[0][0]
    sample = sample.float().numpy()

    kmeans.train(sample)

    centroids = torch.from_numpy(kmeans.centroids)

    print_memory_stats(f'RANK:0*')

    if return_value_queue is not None:
        return_value_queue.put(centroids)

    return centroids


"""
TODOs:

1. Consider saving/using heldout_avg_residual as a vector --- that is, using 128 averages!

2. Consider the operations with .cuda() tensors. Are all of them good for OOM?
"""