import logging import multiprocessing from multiprocessing.connection import Connection import multiprocessing.context import time from typing import Generator, Callable, List, Tuple, cast from uuid import UUID from hypothesis import given import hypothesis.strategies as st import pytest import chromadb from chromadb.api import ClientAPI, ServerAPI from chromadb.config import Settings, System from chromadb.segment import VectorReader from chromadb.segment.impl.manager.local import LocalSegmentManager import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants from strategies import hashing_embedding_function from chromadb.test.property.test_embeddings import ( EmbeddingStateMachineStates, trace, EmbeddingStateMachineBase, ) from hypothesis.stateful import ( run_state_machine_as_test, rule, precondition, initialize, MultipleResults, ) import os import shutil import tempfile from chromadb.api.client import Client as ClientCreator from chromadb.utils.embedding_functions import DefaultEmbeddingFunction import numpy as np CreatePersistAPI = Callable[[], ServerAPI] configurations = [ Settings( chroma_api_impl="chromadb.api.segment.SegmentAPI", chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", allow_reset=True, is_persistent=True, persist_directory=tempfile.mkdtemp(), ), ] @pytest.fixture(scope="module", params=configurations) def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: configuration = request.param save_path = configuration.persist_directory # Create if it doesn't exist if not os.path.exists(save_path): os.makedirs(save_path, exist_ok=True) yield configuration # Remove if it exists if os.path.exists(save_path): shutil.rmtree(save_path, ignore_errors=True) collection_st = st.shared( strategies.collections( with_hnsw_params=True, with_persistent_hnsw_params=st.just(True), # Makes it more likely to find persist-related bugs (by default these are set to 2000). # Lower values make it more likely that a test will trigger a persist to disk. max_hnsw_batch_size=10, max_hnsw_sync_threshold=10, ), key="coll", ) @st.composite def collection_and_recordset_strategy( draw: st.DrawFn, ) -> Tuple[strategies.Collection, strategies.RecordSet]: collection = draw( strategies.collections( with_hnsw_params=True, with_persistent_hnsw_params=st.just(True), # Makes it more likely to find persist-related bugs (by default these are set to 2000). max_hnsw_batch_size=10, max_hnsw_sync_threshold=10, ) ) recordset = draw(strategies.recordsets(st.just(collection))) return collection, recordset @given( collection_and_recordset_strategies=st.lists( collection_and_recordset_strategy(), min_size=1, unique_by=(lambda x: x[0].name, lambda x: x[0].name), ) ) def test_persist( settings: Settings, collection_and_recordset_strategies: List[ Tuple[strategies.Collection, strategies.RecordSet] ], ) -> None: system_1 = System(settings) system_1.start() client_1 = ClientCreator.from_system(system_1) client_1.reset() for ( collection_strategy, recordset_strategy, ) in collection_and_recordset_strategies: coll = client_1.create_collection( name=collection_strategy.name, metadata=collection_strategy.metadata, # type: ignore[arg-type] embedding_function=collection_strategy.embedding_function, ) coll.add(**recordset_strategy) # type: ignore[arg-type] invariants.count(coll, recordset_strategy) invariants.metadatas_match(coll, recordset_strategy) invariants.documents_match(coll, recordset_strategy) invariants.ids_match(coll, recordset_strategy) invariants.ann_accuracy( coll, recordset_strategy, embedding_function=collection_strategy.embedding_function, ) system_1.stop() del client_1 del system_1 system_2 = System(settings) system_2.start() client_2 = ClientCreator.from_system(system_2) for ( collection_strategy, recordset_strategy, ) in collection_and_recordset_strategies: coll = client_2.get_collection( name=collection_strategy.name, embedding_function=collection_strategy.embedding_function, ) invariants.count(coll, recordset_strategy) invariants.metadatas_match(coll, recordset_strategy) invariants.documents_match(coll, recordset_strategy) invariants.ids_match(coll, recordset_strategy) invariants.ann_accuracy( coll, recordset_strategy, embedding_function=collection_strategy.embedding_function, ) system_2.stop() del client_2 del system_2 def test_sync_threshold(settings: Settings) -> None: system = System(settings) system.start() client = ClientCreator.from_system(system) collection = client.create_collection( name="test", metadata={"hnsw:batch_size": 3, "hnsw:sync_threshold": 3} ) manager = system.instance(LocalSegmentManager) segment = manager.get_segment(collection.id, VectorReader) def get_index_last_modified_at() -> float: # Time resolution on Windows can be up to 10ms time.sleep(0.1) try: return os.path.getmtime(segment._get_metadata_file()) # type: ignore[attr-defined] except FileNotFoundError: return -1 last_modified_at = get_index_last_modified_at() collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]]) # type: ignore[arg-type] # Should not have yet persisted assert get_index_last_modified_at() == last_modified_at last_modified_at = get_index_last_modified_at() # Now there's 3 additions, and the sync threshold is 3... collection.add(ids=["3"], embeddings=[[3.0]]) # type: ignore[arg-type] # ...so it should have persisted assert get_index_last_modified_at() > last_modified_at last_modified_at = get_index_last_modified_at() # The same thing should happen with upserts collection.upsert(ids=["1", "2", "3"], embeddings=[[1.0], [2.0], [3.0]]) # type: ignore[arg-type] # Should have persisted assert get_index_last_modified_at() > last_modified_at last_modified_at = get_index_last_modified_at() # Mixed usage should also trigger persistence collection.add(ids=["4"], embeddings=[[4.0]]) # type: ignore[arg-type] collection.upsert(ids=["1", "2"], embeddings=[[1.0], [2.0]]) # type: ignore[arg-type] # Should have persisted assert get_index_last_modified_at() > last_modified_at last_modified_at = get_index_last_modified_at() # Invalid updates should also trigger persistence collection.add(ids=["5"], embeddings=[[5.0]]) # type: ignore[arg-type] collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]]) # type: ignore[arg-type] # Should have persisted assert get_index_last_modified_at() > last_modified_at last_modified_at = get_index_last_modified_at() def load_and_check( settings: Settings, collection_name: str, record_set: strategies.RecordSet, conn: Connection, ) -> None: try: system = System(settings) system.start() client = ClientCreator.from_system(system) coll = client.get_collection( name=collection_name, embedding_function=strategies.not_implemented_embedding_function(), # type: ignore[arg-type] ) invariants.count(coll, record_set) invariants.metadatas_match(coll, record_set) invariants.documents_match(coll, record_set) invariants.ids_match(coll, record_set) invariants.ann_accuracy(coll, record_set) system.stop() except Exception as e: conn.send(e) raise e def get_multiprocessing_context(): # type: ignore[no-untyped-def] try: # Run the invariants in a new process to bypass any shared state/caching (which would defeat the purpose of the test) # (forkserver is used because it's much faster than spawn—it will spawn a new, minimal singleton process and then fork that singleton) ctx = multiprocessing.get_context("forkserver") # This is like running `import chromadb` in the single process that is forked rather than importing it in each forked process. # Gives a ~3x speedup since importing chromadb is fairly expensive. ctx.set_forkserver_preload(["chromadb"]) return ctx except Exception: # forkserver/fork is not available on Windows return multiprocessing.get_context("spawn") class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): persist = "persist" MIN_STATE_CHANGES_BEFORE_PERSIST = 5 class PersistEmbeddingsStateMachine(EmbeddingStateMachineBase): def __init__(self, client: ClientAPI, settings: Settings): self.client = client self.settings = settings self.min_state_changes_left_before_persisting = MIN_STATE_CHANGES_BEFORE_PERSIST self.client.reset() super().__init__(self.client) @initialize(collection=collection_st) # type: ignore def initialize(self, collection: strategies.Collection): self.client.reset() self.collection = self.client.create_collection( name=collection.name, metadata=collection.metadata, # type: ignore[arg-type] embedding_function=collection.embedding_function, ) self.embedding_function = collection.embedding_function trace("init") self.on_state_change(EmbeddingStateMachineStates.initialize) self.record_set_state = strategies.StateMachineRecordSet( ids=[], metadatas=[], documents=[], embeddings=[] ) @precondition( lambda self: len(self.record_set_state["ids"]) >= 1 and self.min_state_changes_left_before_persisting <= 0 ) @rule() def persist(self) -> None: self.on_state_change(PersistEmbeddingsStateMachineStates.persist) collection_name = self.collection.name conn1, conn2 = multiprocessing.Pipe() ctx = get_multiprocessing_context() # type: ignore[no-untyped-call] p = ctx.Process( target=load_and_check, args=(self.settings, collection_name, self.record_set_state, conn2), ) p.start() p.join() if conn1.poll(): e = conn1.recv() raise e p.close() def on_state_change(self, new_state: str) -> None: super().on_state_change(new_state) if new_state == PersistEmbeddingsStateMachineStates.persist: self.min_state_changes_left_before_persisting = ( MIN_STATE_CHANGES_BEFORE_PERSIST ) else: self.min_state_changes_left_before_persisting -= 1 def teardown(self) -> None: self.client.reset() def test_persist_embeddings_state( caplog: pytest.LogCaptureFixture, settings: Settings ) -> None: caplog.set_level(logging.ERROR) client = chromadb.Client(settings) run_state_machine_as_test( lambda: PersistEmbeddingsStateMachine(settings=settings, client=client), ) # type: ignore def test_delete_less_than_k( caplog: pytest.LogCaptureFixture, settings: Settings ) -> None: client = chromadb.Client(settings) state = PersistEmbeddingsStateMachine(settings=settings, client=client) state.initialize( collection=strategies.Collection( name="A00", metadata={ "hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128, "hnsw:sync_threshold": 3, "hnsw:batch_size": 3, }, embedding_function=None, id=UUID("2d3eddc7-2314-45f4-a951-47a9a8e099d2"), dimension=2, dtype=np.float16, known_metadata_keys={}, known_document_keywords=[], has_documents=False, has_embeddings=True, ) ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() (embedding_ids_0,) = state.add_embeddings(record_set={"ids": ["0"], "embeddings": [[0.09765625, 0.430419921875]], "metadatas": [None], "documents": None}) # type: ignore state.ann_accuracy() # recall: 1.0, missing 0 out of 1, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() embedding_ids_1, embedding_ids_2 = state.add_embeddings(record_set={"ids": ["1", "2"], "embeddings": [[0.20556640625, 0.08978271484375], [-0.1527099609375, 0.291748046875]], "metadatas": [None, None], "documents": None}) # type: ignore state.ann_accuracy() # recall: 1.0, missing 0 out of 3, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.delete_by_ids(ids=[embedding_ids_2]) state.ann_accuracy() state.teardown() # Ideally this scenario would be exercised by Hypothesis, but most runs don't seem to trigger this particular state. def test_delete_add_after_persist(settings: Settings) -> None: client = chromadb.Client(settings) state = PersistEmbeddingsStateMachine(settings=settings, client=client) state.initialize( collection=strategies.Collection( name="A00", metadata={ "hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128, # Important: both batch_size and sync_threshold are 3 "hnsw:batch_size": 3, "hnsw:sync_threshold": 3, }, embedding_function=DefaultEmbeddingFunction(), # type: ignore[arg-type] id=UUID("0851f751-2f11-4424-ab23-4ae97074887a"), dimension=2, dtype=None, known_metadata_keys={}, known_document_keywords=[], has_documents=False, has_embeddings=True, ) ) state.add_embeddings( record_set={ # Add 3 records to hit the batch_size and sync_threshold "ids": ["0", "1", "2"], "embeddings": [[0, 0], [0, 0], [0, 0]], "metadatas": [None, None, None], "documents": None, } ) # Delete and then re-add record state.delete_by_ids(ids=["0"]) state.add_embeddings( record_set={ "ids": ["0"], "embeddings": [[1, 1]], "metadatas": [None], "documents": None, } ) # At this point, the changes above are not fully persisted state.fields_match() def test_batch_size_less_than_sync_with_duplicate_adds_results_in_skipped_seq_ids( caplog: pytest.LogCaptureFixture, settings: Settings ) -> None: # NOTE(hammadb) this test was autogenerate by hypothesis and added here to ensure that the test is run # in the future. It tests a case where the max seq id was incorrect in response to the same # id being added multiple times in a bathc. client = chromadb.Client(settings) state = PersistEmbeddingsStateMachine(settings=settings, client=client) state.initialize( collection=strategies.Collection( name="JqzMs4pPm14c\n", metadata={ "hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128, "hnsw:sync_threshold": 9, "hnsw:batch_size": 7, }, embedding_function=hashing_embedding_function(dim=92, dtype=np.float64), id=UUID("45c5c816-0a90-4293-8d01-4325ff860040"), dimension=92, dtype=np.float64, known_metadata_keys={}, known_document_keywords=[], has_documents=False, has_embeddings=True, ) ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() ( embedding_ids_0, embedding_ids_1, embedding_ids_2, embedding_ids_3, embedding_ids_4, embedding_ids_5, embedding_ids_6, ) = cast( MultipleResults[str], state.add_embeddings( record_set={ "ids": ["N", "e8r6", "4", "Yao", "qFjA2c", "jHCv", "2"], "embeddings": [ [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0], [6.0, 6.0, 6.0], ], "metadatas": None, "documents": None, } ), ) state.ann_accuracy() # recall: 1.0, missing 0 out of 7, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() print("\n\n") (_) = state.add_embeddings( record_set={ "ids": ["MVu393QTc"], "embeddings": [[7.0, 7.0, 7.0]], "metadatas": None, "documents": None, } ) state.ann_accuracy() # recall: 1.0, missing 0 out of 8, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() ( _, _, _, _, embedding_ids_12, _, _, _, _, embedding_ids_17, embedding_ids_18, _, _, _, embedding_ids_22, _, _, ) = cast( MultipleResults[str], state.add_embeddings( record_set={ "ids": [ "CyF0Mk-", "q_Fwu", "2D2sQSFogDgPLkcfT", "SrwuQHQ6w4f51qWr2enLPQw8uKYs1", "G", "wdzt", "5W", "8tpsn", "fJbV7z", "5", "V", "1iFkoJX", "Zw4u", "Fc", "7", "vEEwrP", "Yf", ], "embeddings": [ [8.0, 8.0, 8.0], [9.0, 9.0, 9.0], [10.0, 10.0, 10.0], [11.0, 11.0, 11.0], [12.0, 12.0, 12.0], [13.0, 13.0, 13.0], [14.0, 14.0, 14.0], [15.0, 15.0, 15.0], [16.0, 16.0, 16.0], [17.0, 17.0, 17.0], [18.0, 18.0, 18.0], [19.0, 19.0, 19.0], [20.0, 20.0, 20.0], [21.0, 21.0, 21.0], [22.0, 22.0, 22.0], [23.0, 23.0, 23.0], [24.0, 24.0, 24.0], ], "metadatas": None, "documents": None, } ), ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.add_embeddings( record_set={ "ids": ["0", "df_RWhR0HelOcv"], "embeddings": [[25.0, 25.0, 25.0], [26.0, 26.0, 26.0]], "metadatas": [None, None], "documents": None, } ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.add_embeddings( record_set={ "ids": ["3R", "9_", "44u", "3B", "MZCXZDS", "Uelx"], "embeddings": [ [27.0, 27.0, 27.0], [28.0, 28.0, 28.0], [29.0, 29.0, 29.0], [30.0, 30.0, 30.0], [31.0, 31.0, 31.0], [32.0, 32.0, 32.0], ], "metadatas": None, "documents": None, } ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.persist() state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.add_embeddings( record_set={ "ids": "YlVm", "embeddings": [[33.0, 33.0, 33.0]], "metadatas": None, "documents": None, } ) state.ann_accuracy() # recall: 1.0, missing 0 out of 34, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.add_embeddings( record_set={ "ids": ["Rk1", "TPL"], "embeddings": [[34.0, 34.0, 34.0], [35.0, 35.0, 35.0]], "metadatas": [None, None], "documents": None, } ) state.ann_accuracy() # recall: 1.0, missing 0 out of 36, accuracy threshold 1e-06 state.count() state.fields_match() state.log_size_below_max() state.no_duplicates() state.add_embeddings( record_set={ "ids": [ "CyF0Mk-", "q_Fwu", "2D2sQSFogDgPLkcfT", "SrwuQHQ6w4f51qWr2enLPQw8uKYs1", embedding_ids_12, "wdzt", "5W", "8tpsn", "fJbV7z", embedding_ids_17, embedding_ids_18, "1iFkoJX", "Zw4u", "Fc", embedding_ids_22, "vEEwrP", "Yf", ], "embeddings": [ [8.0, 8.0, 8.0], [9.0, 9.0, 9.0], [10.0, 10.0, 10.0], [11.0, 11.0, 11.0], [12.0, 12.0, 12.0], [13.0, 13.0, 13.0], [14.0, 14.0, 14.0], [15.0, 15.0, 15.0], [16.0, 16.0, 16.0], [17.0, 17.0, 17.0], [18.0, 18.0, 18.0], [19.0, 19.0, 19.0], [20.0, 20.0, 20.0], [21.0, 21.0, 21.0], [22.0, 22.0, 22.0], [23.0, 23.0, 23.0], [24.0, 24.0, 24.0], ], "metadatas": None, "documents": None, } ) state.ann_accuracy() state.count() state.fields_match() state.log_size_below_max() state.teardown()
Memory