"""Property-based tests for adding records to a Chroma collection."""

import uuid
from random import randint
from typing import cast, List, Any, Dict

import hypothesis
import numpy as np
import pytest
import hypothesis.strategies as st
from hypothesis import given, settings

from chromadb.api import ClientAPI
from chromadb.api.types import Embeddings, Metadatas
from chromadb.test.conftest import (
    reset,
    NOT_CLUSTER_ONLY,
    override_hypothesis_profile,
)
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase
from chromadb.utils.batch_utils import create_batches

# Shared so that the collection and the record set drawn for the same example
# agree with each other.
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


# Hypothesis tends to generate smaller values, so we explicitly segregate the
# tests into tiers: Small and Medium. Hypothesis struggles to generate large
# record sets, so we explicitly create a large record set without using
# Hypothesis (see create_large_recordset and test_add_large).
@given(
    collection=collection_st,
    record_set=strategies.recordsets(collection_st, min_size=1, max_size=500),
    should_compact=st.booleans(),
)
@settings(
    deadline=None,
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=500),
        fast=hypothesis.settings(max_examples=200),
    ),
)
def test_add_small(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Add generated record sets of 1-500 records and check the add invariants."""
    _test_add(client, collection, record_set, should_compact)


@given(
    collection=collection_st,
    record_set=strategies.recordsets(
        collection_st,
        min_size=250,
        max_size=500,
        num_unique_metadata=5,
        min_metadata_size=1,
        max_metadata_size=5,
    ),
    should_compact=st.booleans(),
)
@settings(
    deadline=None,
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=10),
        fast=hypothesis.settings(max_examples=5),
    ),
    suppress_health_check=[
        hypothesis.HealthCheck.too_slow,
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
def test_add_medium(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Add generated record sets of 250-500 records and check the add invariants."""
    # Cluster tests transmit their results over grpc, which has a payload limit.
    # This breaks the ann_accuracy invariant by default, since
    # the vector reader returns a payload of dataset size. So we need to batch
    # the queries in the ann_accuracy invariant.
    _test_add(client, collection, record_set, should_compact, batch_ann_accuracy=True)


def _test_add(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
    batch_ann_accuracy: bool = False,
) -> None:
    """Add ``record_set`` to a freshly created collection and check invariants.

    The client is reset and a new collection is created from ``collection``.
    The records are added via ``create_batches`` and then checked with
    ``invariants.count`` and ``invariants.ann_accuracy``.

    Args:
        client: API client under test; it is reset before the collection is
            created.
        collection: Collection spec (name, metadata, embedding function).
        record_set: Records to add. The raw record set is passed to
            ``create_batches``; the invariants use ``invariants.wrap_all`` of it.
        should_compact: When true, and ``NOT_CLUSTER_ONLY`` is false
            (presumably a cluster run — confirm in conftest), and there are more
            than 10 records, wait for the collection version to increase before
            checking invariants.
        batch_ann_accuracy: When true, run the ANN accuracy check over the
            query indices in chunks of 10 instead of all at once.
    """
    reset(client)

    # TODO: Generative embedding functions
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    initial_version = cast(int, coll.get_model()["version"])

    normalized_record_set = invariants.wrap_all(record_set)
    num_records = len(normalized_record_set["ids"])

    # TODO: The type of add() is incorrect as it does not allow for metadatas
    # like [{"a": 1}, None, {"a": 3}]
    for batch in create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    ):
        coll.add(*batch)

    # Only wait for compaction if the size of the collection is
    # some minimal size
    if not NOT_CLUSTER_ONLY and should_compact and num_records > 10:
        # Wait for the model to be updated
        wait_for_version_increase(client, collection.name, initial_version)

    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))

    n_results = max(1, num_records // 10)

    if batch_ann_accuracy:
        batch_size = 10
        for i in range(0, num_records, batch_size):
            invariants.ann_accuracy(
                coll,
                cast(strategies.RecordSet, normalized_record_set),
                n_results=n_results,
                embedding_function=collection.embedding_function,
                query_indices=list(range(i, min(i + batch_size, num_records))),
            )
    else:
        invariants.ann_accuracy(
            coll,
            cast(strategies.RecordSet, normalized_record_set),
            n_results=n_results,
            embedding_function=collection.embedding_function,
        )


# Hypothesis struggles to generate large record sets, so we explicitly create
# a large record set.
def create_large_recordset(
    min_size: int = 45000,
    max_size: int = 50000,
) -> strategies.RecordSet:
    """Build a record set whose size is drawn uniformly from [min_size, max_size].

    Both bounds are inclusive (``random.randint``). Each record gets:
    - a fresh UUID4 id;
    - metadata ``{"some_key": str(i)}``;
    - document ``"Document {i}"``;
    - the same embedding ``[1, 2, 3]``.
    """
    size = randint(min_size, max_size)

    ids = [str(uuid.uuid4()) for _ in range(size)]
    metadatas = [{"some_key": f"{i}"} for i in range(size)]
    documents = [f"Document {i}" for i in range(size)]
    embeddings = [[1, 2, 3] for _ in range(size)]
    record_set: Dict[str, List[Any]] = {
        "ids": ids,
        "embeddings": cast(Embeddings, embeddings),
        "metadatas": metadatas,
        "documents": documents,
    }
    return cast(strategies.RecordSet, record_set)


@given(collection=collection_st, should_compact=st.booleans())
@settings(deadline=None, max_examples=5)
def test_add_large(
    client: ClientAPI, collection: strategies.Collection, should_compact: bool
) -> None:
    """Add 10,000-50,000 records in batches and check the record count."""
    reset(client)

    record_set = create_large_recordset(
        min_size=10000,
        max_size=50000,
    )
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    normalized_record_set = invariants.wrap_all(record_set)
    initial_version = cast(int, coll.get_model()["version"])

    for batch in create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    ):
        coll.add(*batch)

    if (
        not NOT_CLUSTER_ONLY
        and should_compact
        and len(normalized_record_set["ids"]) > 10
    ):
        # Wait for the model to be updated. Since the record set is larger,
        # allow some additional time.
        wait_for_version_increase(
            client, collection.name, initial_version, additional_time=240
        )

    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))


@given(collection=collection_st)
@settings(deadline=None, max_examples=1)
def test_add_large_exceeding(
    client: ClientAPI, collection: strategies.Collection
) -> None:
    """A single add() larger than the client's max batch size must be rejected."""
    reset(client)

    max_batch_size = client.get_max_batch_size()
    # randint bounds are inclusive. With min_size == max_batch_size, a batch of
    # exactly max_batch_size records could be drawn. That batch does not
    # *exceed* the limit, which made this test flaky. Start one past the limit.
    record_set = create_large_recordset(
        min_size=max_batch_size + 1,
        max_size=max_batch_size + 100,  # Exceed the max batch size by up to 100 records
    )
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    with pytest.raises(Exception) as e:
        coll.add(**record_set)  # type: ignore[arg-type]
    assert "exceeds maximum batch size" in str(e.value)


# TODO: This test fails right now because the ids are not sorted by the input order
@pytest.mark.xfail(
    reason="This is expected to fail right now. We should change the API to sort the "
    "ids by input order."
)
def test_out_of_order_ids(client: ClientAPI) -> None:
    """get(ids=...) should return records in the order the ids were requested."""
    reset(client)
    ooo_ids = [
        "40",
        "05",
        "8",
        "6",
        "10",
        "01",
        "00",
        "3",
        "04",
        "20",
        "02",
        "9",
        "30",
        "11",
        "13",
        "2",
        "0",
        "7",
        "06",
        "5",
        "50",
        "12",
        "03",
        "4",
        "1",
    ]

    coll = client.create_collection(
        "test", embedding_function=lambda input: [[1, 2, 3] for _ in input]  # type: ignore
    )
    embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
    coll.add(ids=ooo_ids, embeddings=embeddings)
    get_ids = coll.get(ids=ooo_ids)["ids"]
    assert get_ids == ooo_ids


def test_add_partial(client: ClientAPI) -> None:
    """Tests adding a record set with some of the fields set to None."""
    reset(client)

    coll = client.create_collection("test")
    # TODO: We need to clean up the api types to support this typing
    coll.add(
        ids=["1", "2", "3"],
        # All embeddings must be provided, or else None - no partial lists allowed
        embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]],  # type: ignore
        # Metadatas can always be partial
        metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
        # Documents are optional if embeddings are provided
        documents=["a", "b", None],  # type: ignore
    )

    results = coll.get()
    assert results["ids"] == ["1", "2", "3"]
    assert results["metadatas"] == [{"a": 1}, None, {"a": 3}]
    assert results["documents"] == ["a", "b", None]
Memory