import os import argparse import torch from tqdm import tqdm import ujson import shutil def main(args): in_file = args.input out_file = args.output # Get num_chunks from metadata filepath = os.path.join(in_file, 'metadata.json') with open(filepath, 'r') as f: metadata = ujson.load(f) num_chunks = metadata['num_chunks'] print(f"Num_chunks = {num_chunks}") # Create output dir if not already created if not os.path.exists(out_file): os.makedirs(out_file) ## Coalesce doclens ## print("Coalescing doclens files...") temp = [] # read into one large list for i in tqdm(range(num_chunks)): filepath = os.path.join(in_file, f'doclens.{i}.json') with open(filepath, 'r') as f: chunk = ujson.load(f) temp.extend(chunk) # write to output json filepath = os.path.join(out_file, 'doclens.0.json') with open(filepath, 'w') as f: ujson.dump(temp, f) ## Coalesce codes ## print("Coalescing codes files...") temp = torch.empty(0, dtype=torch.int32) # read into one large tensor for i in tqdm(range(num_chunks)): filepath = os.path.join(in_file, f'{i}.codes.pt') chunk = torch.load(filepath) temp = torch.cat((temp, chunk)) # save length of index index_len = temp.size()[0] # write to output tensor filepath = os.path.join(out_file, '0.codes.pt') torch.save(temp, filepath) ## Coalesce residuals ## print("Coalescing residuals files...") # Allocate all the memory needed in the beginning. Starting from torch.empty() and concatenating repeatedly results in excessive memory use and is much much slower. temp = torch.zeros(((metadata['num_embeddings'], int(metadata['config']['dim'] * metadata['config']['nbits'] // 8))), dtype=torch.uint8) cur_offset = 0 # read into one large tensor for i in tqdm(range(num_chunks)): filepath = os.path.join(in_file, f'{i}.residuals.pt') chunk = torch.load(filepath) temp[cur_offset : cur_offset + len(chunk):] = chunk cur_offset += len(chunk) print("Saving residuals to output directory (this may take a few minutes)...") # write to output tensor filepath = os.path.join(out_file, '0.residuals.pt') torch.save(temp, filepath) # save metadata.json metadata['num_chunks'] = 1 filepath = os.path.join(out_file, 'metadata.json') with open(filepath, 'w') as f: ujson.dump(metadata, f, indent=4) metadata_0 = {} metadata_0["num_embeddings"] = metadata["num_embeddings"] metadata_0["passage_offset"] = 0 metadata_0["embedding_offset"] = 0 filepath = os.path.join(in_file, str(num_chunks-1) + '.metadata.json') with open(filepath, 'r') as f: metadata_last = ujson.load(f) metadata_0["num_passages"] = int(metadata_last["num_passages"]) + int(metadata_last["passage_offset"]) filepath = os.path.join(out_file, '0.metadata.json') with open(filepath, 'w') as f: ujson.dump(metadata_0, f, indent=4) filepath = os.path.join(in_file, 'plan.json') with open(filepath, 'r') as f: plan = ujson.load(f) plan['num_chunks'] = 1 filepath = os.path.join(out_file, 'plan.json') with open(filepath, 'w') as f: ujson.dump(plan, f, indent=4) other_files = ['avg_residual.pt', 'buckets.pt', 'centroids.pt', 'ivf.pt', 'ivf.pid.pt'] for filename in other_files: filepath = os.path.join(in_file, filename) if os.path.isfile(filepath): shutil.copy(filepath, out_file) print("Saved index to output directory {}.".format(out_file)) print("Number of embeddings = {}".format(index_len)) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Coalesce multi-file index into a single file.") parser.add_argument( "--input", type=str, required=True, help="Path to input index directory" ) parser.add_argument( "--output", type=str, required=True, help="Path to output index directory" ) args = parser.parse_args() main(args)
Memory