from contextlib import contextmanager
from typing import Any, Dict, Iterator

import hypothesis.strategies as st
import pytest
from hypothesis import assume, given, settings

from chromadb.api import ServerAPI
from chromadb.config import System
from chromadb.test.auth.strategies import (
    random_token,
    random_token_transport_header,
    token_test_conf,
)
from chromadb.test.conftest import _fastapi_fixture

_SERVER_AUTHN_PROVIDER = "chromadb.auth.token_authn.TokenAuthenticationServerProvider"
_CLIENT_AUTH_PROVIDER = "chromadb.auth.token_authn.TokenAuthClientProvider"


@contextmanager
def _token_auth_server(
    tconf: Dict[str, Any],
    transport_header: str,
    persistence: bool,
    client_token: str,
) -> Iterator[System]:
    """Start a token-authenticated FastAPI server and yield the client ``System``.

    Args:
        tconf: Generated auth config; ``tconf["filename"]`` is passed to the
            server as its credentials file.
        transport_header: Header the client uses to send its token.
        persistence: Whether the fixture should run with persistence enabled.
        client_token: Token the client authenticates with.

    On exit the ``_fastapi_fixture`` generator is resumed past its ``yield``
    so whatever teardown it performs after yielding actually runs.
    """
    fixture = _fastapi_fixture(
        is_persistent=persistence,
        chroma_auth_token_transport_header=transport_header,
        chroma_server_authn_provider=_SERVER_AUTHN_PROVIDER,
        chroma_server_authn_credentials_file=tconf["filename"],
        chroma_client_auth_provider=_CLIENT_AUTH_PROVIDER,
        chroma_client_auth_credentials=client_token,
    )
    system: System = next(fixture)
    try:
        yield system
    finally:
        # Resume the generator so the code after its `yield` runs. Simply
        # dropping (or close()-ing) it would skip any teardown that is not
        # wrapped in a `finally` clause, leaking the server it started.
        next(fixture, None)


@settings(max_examples=10)
@given(token_test_conf(), random_token_transport_header(), st.booleans())
def test_fastapi_server_token_authn_allows_when_it_should_allow(
    tconf: Dict[str, Any], transport_header: str, persistence: bool
) -> None:
    """Every token listed in the credentials config must be accepted."""
    for user in tconf["users"]:
        for token in user["tokens"]:
            with _token_auth_server(
                tconf, transport_header, persistence, token
            ) as system:
                system.reset_state()
                api = system.instance(ServerAPI)
                api.heartbeat()
                assert api.list_collections() == []


@settings(max_examples=10)
@given(
    token_test_conf(), random_token(), random_token_transport_header(), st.booleans()
)
def test_fastapi_server_token_authn_rejects_when_it_should_reject(
    tconf: Dict[str, Any],
    unauthorized_token: str,
    transport_header: str,
    persistence: bool,
) -> None:
    """A token absent from the credentials config must be rejected as Forbidden."""
    configured_tokens = {
        token for user in tconf["users"] for token in user["tokens"]
    }
    # Discard (instead of silently passing) examples where the random token
    # collides with a configured, and therefore authorized, token.
    assume(unauthorized_token not in configured_tokens)
    if not configured_tokens:
        # No configured tokens: nothing in this example to check against.
        return

    # The server configuration and client token are the same for every
    # configured user, so a single server is enough to check rejection.
    with _token_auth_server(
        tconf, transport_header, persistence, unauthorized_token
    ) as system:
        system.reset_state()
        api = system.instance(ServerAPI)
        # heartbeat is called outside pytest.raises: it is expected to succeed
        # even with the unauthorized token.
        api.heartbeat()
        with pytest.raises(Exception) as exc_info:
            api.list_collections()
        # Inspect the exception message itself, not the (truncated) repr of
        # the ExceptionInfo wrapper.
        assert "Forbidden" in str(exc_info.value)
Memory