import os
import ujson
import torch
import numpy as np
import tqdm

from colbert.search.index_loader import IndexLoader
from colbert.indexing.index_saver import IndexSaver
from colbert.indexing.collection_encoder import CollectionEncoder

from colbert.utils.utils import lengths2offsets, print_message, dotdict, flatten
from colbert.indexing.codecs.residual import ResidualCodec
from colbert.indexing.utils import optimize_ivf
from colbert.search.strided_tensor import StridedTensor
from colbert.modeling.checkpoint import Checkpoint
from colbert.utils.utils import print_message, batch
from colbert.data import Collection
from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings
from colbert.indexing.codecs.residual_embeddings_strided import (
    ResidualEmbeddingsStrided,
)
from colbert.indexing.utils import optimize_ivf

# For testing writing into new chunks, can set DEFAULT_CHUNKSIZE smaller (e.g. 1 or 2)
DEFAULT_CHUNKSIZE = 25000


class IndexUpdater:
    """
    IndexUpdater takes in a searcher and adds/removes passages from the searcher.
    A checkpoint for passage-encoding must be provided for adding passages.
    IndexUpdater can also persist the change of passages to the index on disk.

    Sample usage:
        index_updater = IndexUpdater(config, searcher, checkpoint)

        added_pids = index_updater.add(passages)  # all passages added to searcher with their pids returned
        index_updater.remove(pids)  # all pids within pids removed from searcher

        searcher.search()  # the search now reflects the added & removed passages

        index_updater.persist_to_disk()  # added & removed passages persisted to index on disk
        searcher = Searcher(index, config)  # if we reload the searcher now from disk index, the changes we made persist
    """

    def __init__(self, config, searcher, checkpoint=None):
        """
        Args:
            config: ColBERT config, forwarded to Checkpoint / CollectionEncoder.
            searcher: searcher whose ``ranker`` (embeddings, doclens, ivf, codec)
                is edited in place.
            checkpoint: optional checkpoint used to encode new passages; add()
                raises ValueError when it is missing.
        """
        self.config = config
        self.searcher = searcher
        self.index_path = searcher.index

        self.has_checkpoint = False
        if checkpoint:
            self.has_checkpoint = True
            self.checkpoint = Checkpoint(checkpoint, config)
            self.encoder = CollectionEncoder(config, self.checkpoint)

        self._load_disk_ivf()

        # Variables to track removal / append of passages since the last persist.
        self.removed_pids = []
        # In-memory index of the first embedding / pid not yet written to disk.
        self.first_new_emb = torch.sum(self.searcher.ranker.doclens).item()
        self.first_new_pid = len(self.searcher.ranker.doclens)

    def remove(self, pids):
        """
        Input:
            pids: list(int)
        Return: None

        Removes a list of pids from the searcher;
        these pids will no longer appear in future searches with this searcher.
        To erase passage data from the index, call persist_to_disk() after calling remove().

        Raises:
            ValueError: if any pid is outside [0, number of passages).
        """
        invalid_pids = self._check_pids(pids)
        if invalid_pids:
            raise ValueError("Invalid PIDs", invalid_pids)

        print_message(f"#> Removing pids: {pids}...")
        self._remove_pid_from_ivf(pids)
        self.removed_pids.extend(pids)

    def create_embs_and_doclens(
        self, passages, embs_path="embs.pt", doclens_path="doclens.pt", persist=False
    ):
        """Encode and compress ``passages``.

        Returns:
            (compressed_embs, doclens) as produced by the searcher's codec and the
            encoder. When ``persist`` is true, both are also written with
            torch.save to ``embs_path`` / ``doclens_path``.
        """
        embs, doclens = self.encoder.encode_passages(passages)
        compressed_embs = self.searcher.ranker.codec.compress(embs)

        if persist:
            torch.save(compressed_embs, embs_path)
            torch.save(doclens, doclens_path)

        return compressed_embs, doclens

    def update_searcher(self, compressed_embs, doclens, curr_pid):
        """Splice newly compressed passages into the in-memory searcher.

        Args:
            compressed_embs: compressed embeddings (exposing ``codes`` and
                ``residuals``) of all new passages, concatenated in passage order.
            doclens: number of embeddings of each new passage.
            curr_pid: pid of the first new passage; subsequent passages receive
                consecutive pids.
        """
        ranker = self.searcher.ranker
        doclens = [int(doclen) for doclen in doclens]

        # NOTE: For codes and residuals, the tensors end with padding of length 512,
        # hence we concatenate the new appendage in front of the padding.
        ranker.embeddings.codes = torch.cat(
            (
                ranker.embeddings.codes[:-512],
                compressed_embs.codes,
                ranker.embeddings.codes[-512:],
            )
        )
        ranker.embeddings.residuals = torch.cat(
            (
                ranker.embeddings.residuals[:-512],
                compressed_embs.residuals,
                ranker.embeddings.residuals[-512:],
            ),
            dim=0,
        )
        ranker.doclens = torch.cat(
            (ranker.doclens, torch.tensor(doclens, dtype=ranker.doclens.dtype))
        )

        # Collect, for every centroid, the new pids having an embedding in it, so
        # the IVF is rebuilt once rather than once per added passage.
        pids_per_centroid = [[] for _ in range(len(self.curr_ivf_lengths))]
        start = 0
        for offset, doclen in enumerate(doclens):
            end = start + doclen
            partitions, _ = self._build_passage_partitions(
                compressed_embs.codes[start:end]
            )
            for centroid in partitions.tolist():
                pids_per_centroid[centroid].append(curr_pid + offset)
            start = end
        assert start == sum(doclens)

        ivf, ivf_lengths = self._merge_pids_into_ivf(
            self.curr_ivf.tolist(), self.curr_ivf_lengths.tolist(), pids_per_centroid
        )

        # Replace the current ivf with the new one (dtypes preserved)
        self.curr_ivf = torch.tensor(ivf, dtype=self.curr_ivf.dtype)
        self.curr_ivf_lengths = torch.tensor(
            ivf_lengths, dtype=self.curr_ivf_lengths.dtype
        )

        # Update new ivf in searcher
        new_ivf_tensor = StridedTensor(
            self.curr_ivf, self.curr_ivf_lengths, use_gpu=False
        )
        assert new_ivf_tensor != ranker.ivf
        ranker.ivf = new_ivf_tensor

        # Rebuild StridedTensor within searcher
        ranker.set_embeddings_strided()

    def add(self, passages):
        """
        Input:
            passages: list(string)
        Output:
            passage_ids: list(int)

        Adds new passages to the searcher and returns their pids.
        To add passages to the index on disk, call persist_to_disk() after calling add().

        Raises:
            ValueError: if no checkpoint was provided at initialization.
        """
        if not self.has_checkpoint:
            raise ValueError(
                "No checkpoint was provided at IndexUpdater initialization."
            )

        if not passages:
            # Nothing to encode; avoids handing an empty batch to the encoder/codec.
            return []

        # Find pid for the first added passage
        start_pid = len(self.searcher.ranker.doclens)

        compressed_embs, doclens = self.create_embs_and_doclens(passages)
        self.update_searcher(compressed_embs, doclens, start_pid)

        print_message(f"#> Added {len(passages)} passages from pid {start_pid}.")
        return list(range(start_pid, start_pid + len(passages)))

    def persist_to_disk(self):
        """
        Persist all previous stored changes in IndexUpdater to index on disk,
        changes include all calls to IndexUpdater.remove() and IndexUpdater.add()
        before persist_to_disk() is called.
        """
        print_message("#> Persisting index changes to disk")

        doclens = self.searcher.ranker.doclens

        # Removed pids below first_new_pid already live in an on-disk chunk.
        # Pids at or above it were added since the last persist and only reach
        # disk further below, so their on-disk removal is deferred until then.
        persisted_removed = [
            pid for pid in self.removed_pids if pid < self.first_new_pid
        ]
        unpersisted_removed = [
            pid for pid in self.removed_pids if pid >= self.first_new_pid
        ]

        # Propagate removals of already-persisted passages to disk
        self._load_metadata()
        for pid in persisted_removed:
            self._remove_passage_from_disk(pid)

        # Propagate all added passages to disk
        # Rationale: keep record of all added passages in IndexUpdater.searcher,
        # divide passages into chunks and create / write chunks here
        self._load_metadata()  # Reload after removal

        # Calculate avg number of passages per chunk (an int: used as a slice bound)
        curr_num_chunks = self.metadata["num_chunks"]
        last_chunk_metadata = self._load_chunk_metadata(curr_num_chunks - 1)
        if curr_num_chunks == 1:
            avg_chunksize = DEFAULT_CHUNKSIZE
        else:
            avg_chunksize = max(
                1, last_chunk_metadata["passage_offset"] // (curr_num_chunks - 1)
            )
        print_message(f"#> Current average chunksize is: {avg_chunksize}.")

        # Calculate number of additional passages we can write to the last chunk
        last_chunk_capacity = max(
            0, avg_chunksize - last_chunk_metadata["num_passages"]
        )
        print_message(
            f"#> The last chunk can hold {last_chunk_capacity} additional passages."
        )

        # Find the first and last passages to be persisted (in-memory indices)
        pid_start = self.first_new_pid
        emb_start = self.first_new_emb
        pid_last = len(doclens)
        emb_last = emb_start + torch.sum(doclens[pid_start:]).item()

        # First populate the last chunk
        if last_chunk_capacity > 0 and pid_start < pid_last:
            pid_end = min(pid_last, pid_start + last_chunk_capacity)
            emb_end = emb_start + torch.sum(doclens[pid_start:pid_end]).item()
            self._write_to_last_chunk(pid_start, pid_end, emb_start, emb_end)
            pid_start, emb_start = pid_end, emb_end

        # Then create new chunks to hold the remaining added passages
        while pid_start < pid_last:
            pid_end = min(pid_last, pid_start + avg_chunksize)
            emb_end = emb_start + torch.sum(doclens[pid_start:pid_end]).item()
            # Write new chunk with id = curr_num_chunks
            self._write_to_new_chunk(
                curr_num_chunks, pid_start, pid_end, emb_start, emb_end
            )
            curr_num_chunks += 1
            pid_start, emb_start = pid_end, emb_end

        assert pid_start == pid_last
        assert emb_start == emb_last

        # Update metadata: only the embeddings appended since the last persist are new.
        print_message("#> Updating metadata for added passages...")
        self.metadata["num_chunks"] = curr_num_chunks
        self.metadata["num_embeddings"] += emb_last - self.first_new_emb
        metadata_path = os.path.join(self.index_path, "metadata.json")
        with open(metadata_path, "w") as output_metadata:
            ujson.dump(self.metadata, output_metadata)

        # The newly added passages are now on disk, so removals among them can be applied.
        for pid in unpersisted_removed:
            self._remove_passage_from_disk(pid)

        # Save current IVF to disk
        optimized_ivf_path = os.path.join(self.index_path, "ivf.pid.pt")
        torch.save((self.curr_ivf, self.curr_ivf_lengths), optimized_ivf_path)
        print_message(f"#> Persisted updated IVF to {optimized_ivf_path}")

        self.removed_pids = []
        self.first_new_emb = torch.sum(self.searcher.ranker.doclens).item()
        self.first_new_pid = len(self.searcher.ranker.doclens)

    # HELPER FUNCTIONS BELOW

    def _load_disk_ivf(self):
        """Load the pid-level IVF into self.curr_ivf / self.curr_ivf_lengths.

        Prefers ``ivf.pid.pt``; otherwise loads ``ivf.pt`` and converts it with
        optimize_ivf.
        """
        print_message("#> Loading IVF...")

        pid_ivf_path = os.path.join(self.index_path, "ivf.pid.pt")
        if os.path.exists(pid_ivf_path):
            ivf, ivf_lengths = torch.load(pid_ivf_path, map_location="cpu")
        else:
            raw_ivf_path = os.path.join(self.index_path, "ivf.pt")
            assert os.path.exists(raw_ivf_path)
            ivf, ivf_lengths = torch.load(raw_ivf_path, map_location="cpu")
            ivf, ivf_lengths = optimize_ivf(ivf, ivf_lengths, self.index_path)

        self.curr_ivf = ivf
        self.curr_ivf_lengths = ivf_lengths

    def _load_metadata(self):
        """Read ``metadata.json`` of the index into self.metadata."""
        with open(os.path.join(self.index_path, "metadata.json")) as f:
            self.metadata = ujson.load(f)

    def _load_chunk_doclens(self, chunk_idx):
        """Return the doclens of chunk ``chunk_idx`` as a tensor."""
        print_message("#> Loading doclens...")
        with open(os.path.join(self.index_path, f"doclens.{chunk_idx}.json")) as f:
            chunk_doclens = ujson.load(f)
        return torch.tensor(chunk_doclens)

    def _load_chunk_codes(self, chunk_idx):
        """Load ``{chunk_idx}.codes.pt`` onto the CPU."""
        codes_path = os.path.join(self.index_path, f"{chunk_idx}.codes.pt")
        return torch.load(codes_path, map_location="cpu")

    def _load_chunk_residuals(self, chunk_idx):
        """Load ``{chunk_idx}.residuals.pt`` onto the CPU."""
        residuals_path = os.path.join(self.index_path, f"{chunk_idx}.residuals.pt")
        return torch.load(residuals_path, map_location="cpu")

    def _load_chunk_metadata(self, chunk_idx):
        """Return the parsed ``{chunk_idx}.metadata.json``."""
        with open(os.path.join(self.index_path, f"{chunk_idx}.metadata.json")) as f:
            return ujson.load(f)

    def _get_chunk_idx(self, pid):
        """Return the index of the on-disk chunk containing ``pid``.

        Raises:
            ValueError: if no chunk covers ``pid``.
        """
        for i in range(self.metadata["num_chunks"]):
            chunk_metadata = self._load_chunk_metadata(i)
            offset = chunk_metadata["passage_offset"]
            if offset <= pid < offset + chunk_metadata["num_passages"]:
                return i
        raise ValueError("Passage ID out of range")

    def _check_pids(self, pids):
        """Return the pids outside [0, number of passages in the searcher)."""
        num_passages = len(self.searcher.ranker.doclens)
        return [pid for pid in pids if pid < 0 or pid >= num_passages]

    def _remove_pid_from_ivf(self, pids):
        """Drop every occurrence of ``pids`` from the current IVF and the searcher.

        Helper for IndexUpdater.remove().
        """
        pid_tensor = torch.tensor(sorted(set(pids)), dtype=self.curr_ivf.dtype)

        # Mask of the positions in the flattened ivf holding a pid to remove
        mask = torch.isin(self.curr_ivf, pid_tensor)
        positions = mask.nonzero().flatten()

        # Owner section of each position = first section whose exclusive end
        # index exceeds it; count removals per centroid section.
        section_ends = torch.cumsum(self.curr_ivf_lengths.long(), dim=0)
        owners = torch.searchsorted(section_ends, positions, right=True)
        removed_counts = torch.bincount(owners, minlength=len(section_ends))

        new_ivf = self.curr_ivf[~mask]
        new_ivf_lengths = self.curr_ivf_lengths - removed_counts.to(
            self.curr_ivf_lengths.dtype
        )

        new_ivf_tensor = StridedTensor(new_ivf, new_ivf_lengths, use_gpu=False)
        assert new_ivf_tensor != self.searcher.ranker.ivf
        self.searcher.ranker.ivf = new_ivf_tensor

        self.curr_ivf = new_ivf
        self.curr_ivf_lengths = new_ivf_lengths

    def _build_passage_partitions(self, codes):
        """Return (sorted unique centroid ids in ``codes``, count of each id).

        Helper for IndexUpdater.add().
        """
        sorted_codes = codes.sort().values
        partitions, ivf_lengths = sorted_codes.unique_consecutive(return_counts=True)
        return partitions, ivf_lengths

    @staticmethod
    def _merge_pids_into_ivf(old_ivf, old_ivf_lengths, pids_per_centroid):
        """Append ``pids_per_centroid[c]`` to the end of centroid c's ivf section.

        Input:
            old_ivf: list(int), flattened ivf (sections in centroid order)
            old_ivf_lengths: list(int), section length per centroid
            pids_per_centroid: list(list(int)), one entry per centroid
        Output:
            (new_ivf, new_ivf_lengths) as lists.
        """
        new_ivf = []
        new_ivf_lengths = []
        runner = 0
        for length, extra in zip(old_ivf_lengths, pids_per_centroid):
            new_ivf.extend(old_ivf[runner : runner + length])
            new_ivf.extend(extra)
            new_ivf_lengths.append(length + len(extra))
            runner += length

        assert runner == len(old_ivf)
        assert sum(new_ivf_lengths) == len(new_ivf)
        return new_ivf, new_ivf_lengths

    def _add_pid_to_ivf(self, partitions, pid, old_ivf, old_ivf_lengths):
        """
        Helper function for IndexUpdater.add()

        Input:
            partitions: list(int), unique centroid ids of the passage
            pid: int, passage id
            old_ivf: list(int), flattened ivf
            old_ivf_lengths: list(int), section length per centroid
        Output:
            (new_ivf, new_ivf_lengths) as lists, with pid appended to the
            section of every centroid in partitions.
        """
        pids_per_centroid = [[] for _ in range(len(old_ivf_lengths))]
        for centroid in partitions:
            pids_per_centroid[int(centroid)].append(pid)
        return self._merge_pids_into_ivf(old_ivf, old_ivf_lengths, pids_per_centroid)

    def _write_to_last_chunk(self, pid_start, pid_end, emb_start, emb_end):
        """Append passages [pid_start, pid_end) to the last on-disk chunk.

        Helper for IndexUpdater.persist_to_disk(). ``emb_start``/``emb_end`` index
        the searcher's in-memory embeddings.
        """
        print_message(f"#> Writing {pid_end - pid_start} passages to the last chunk...")
        num_chunks = self.metadata["num_chunks"]
        last_idx = num_chunks - 1
        ranker = self.searcher.ranker

        # Append to current last chunk
        curr_embs = ResidualEmbeddings.load(self.index_path, last_idx)
        curr_embs.codes = torch.cat(
            (curr_embs.codes, ranker.embeddings.codes[emb_start:emb_end])
        )
        curr_embs.residuals = torch.cat(
            (curr_embs.residuals, ranker.embeddings.residuals[emb_start:emb_end])
        )
        path_prefix = os.path.join(self.index_path, f"{num_chunks - 1}")
        curr_embs.save(path_prefix)

        # Update doclen of last chunk
        curr_doclens = self._load_chunk_doclens(last_idx).tolist()
        curr_doclens.extend(ranker.doclens.tolist()[pid_start:pid_end])
        doclens_path = os.path.join(self.index_path, f"doclens.{num_chunks - 1}.json")
        with open(doclens_path, "w") as output_doclens:
            ujson.dump(curr_doclens, output_doclens)

        # Update metadata of last chunk
        chunk_metadata = self._load_chunk_metadata(last_idx)
        chunk_metadata["num_passages"] += pid_end - pid_start
        chunk_metadata["num_embeddings"] += emb_end - emb_start
        chunk_metadata_path = os.path.join(
            self.index_path, f"{num_chunks - 1}.metadata.json"
        )
        with open(chunk_metadata_path, "w") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

    def _write_to_new_chunk(self, chunk_idx, pid_start, pid_end, emb_start, emb_end):
        """Write passages [pid_start, pid_end) as a new chunk ``chunk_idx``.

        Helper for IndexUpdater.persist_to_disk(). ``emb_start``/``emb_end`` index
        the searcher's in-memory embeddings.
        """
        ranker = self.searcher.ranker

        # Save embeddings to new chunk
        curr_embs = ResidualEmbeddings(
            ranker.embeddings.codes[emb_start:emb_end],
            ranker.embeddings.residuals[emb_start:emb_end],
        )
        path_prefix = os.path.join(self.index_path, f"{chunk_idx}")
        curr_embs.save(path_prefix)

        # Create doclen json file for new chunk
        curr_doclens = ranker.doclens.tolist()[pid_start:pid_end]
        doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json")
        with open(doclens_path, "w+") as output_doclens:
            ujson.dump(curr_doclens, output_doclens)

        # On-disk embedding offsets exclude embeddings erased by removals, so the
        # offset follows the previous chunk rather than the in-memory emb_start.
        if chunk_idx > 0:
            prev_metadata = self._load_chunk_metadata(chunk_idx - 1)
            embedding_offset = (
                prev_metadata["embedding_offset"] + prev_metadata["num_embeddings"]
            )
        else:
            embedding_offset = 0

        # Create metadata json file for new chunk
        chunk_metadata = {
            "passage_offset": pid_start,
            "num_passages": pid_end - pid_start,
            "embedding_offset": embedding_offset,
            "num_embeddings": emb_end - emb_start,
        }
        chunk_metadata_path = os.path.join(
            self.index_path, f"{chunk_idx}.metadata.json"
        )
        with open(chunk_metadata_path, "w+") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

    def _remove_passage_from_disk(self, pid):
        """Erase the embeddings of ``pid`` from its on-disk chunk.

        Helper for IndexUpdater.persist_to_disk(). The pid keeps its slot with a
        doclen of 0 so later pids are not renumbered; embedding offsets of later
        chunks and the index-level num_embeddings are decremented accordingly.
        """
        chunk_idx = self._get_chunk_idx(pid)
        chunk_metadata = self._load_chunk_metadata(chunk_idx)
        i = pid - chunk_metadata["passage_offset"]
        doclens = self._load_chunk_doclens(chunk_idx).tolist()
        codes = self._load_chunk_codes(chunk_idx)
        residuals = self._load_chunk_residuals(chunk_idx)

        # Remove embeddings from codes and residuals
        start = sum(doclens[:i])
        end = start + doclens[i]
        codes = torch.cat((codes[:start], codes[end:]))
        residuals = torch.cat((residuals[:start], residuals[end:]))

        codes_path = os.path.join(self.index_path, f"{chunk_idx}.codes.pt")
        residuals_path = os.path.join(self.index_path, f"{chunk_idx}.residuals.pt")
        torch.save(codes, codes_path)
        torch.save(residuals, residuals_path)

        # Change doclen for passage to 0
        doclen_to_remove = doclens[i]
        doclens[i] = 0
        doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json")
        with open(doclens_path, "w") as output_doclens:
            ujson.dump(doclens, output_doclens)

        # Modify chunk_metadata['num_embeddings'] for chunk_idx
        chunk_metadata["num_embeddings"] -= doclen_to_remove
        chunk_metadata_path = os.path.join(
            self.index_path, f"{chunk_idx}.metadata.json"
        )
        with open(chunk_metadata_path, "w") as output_chunk_metadata:
            ujson.dump(chunk_metadata, output_chunk_metadata)

        # Modify chunk_metadata['embedding_offset'] for all later chunks (minus num_embs_removed)
        for idx in range(chunk_idx + 1, self.metadata["num_chunks"]):
            metadata = self._load_chunk_metadata(idx)
            metadata["embedding_offset"] -= doclen_to_remove
            metadata_path = os.path.join(self.index_path, f"{idx}.metadata.json")
            with open(metadata_path, "w") as output_chunk_metadata:
                ujson.dump(metadata, output_chunk_metadata)

        # Modify num_embeddings in overall metadata (minus num_embs_removed)
        self.metadata["num_embeddings"] -= doclen_to_remove
        metadata_path = os.path.join(self.index_path, "metadata.json")
        with open(metadata_path, "w") as output_metadata:
            ujson.dump(self.metadata, output_metadata)
Memory