import threading
import time
from time import sleep
import numpy as np
import uuid
from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
import hypothesis.strategies as st
from hypothesis.stateful import (
RuleBasedStateMachine,
Bundle,
rule,
run_state_machine_as_test,
consumes,
MultipleResults,
multiple,
)
from hypothesis import given, settings
from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, SegmentCache
from chromadb.types import Segment, SegmentScope
class LRUCacheStateMachine(RuleBasedStateMachine):
    """Model-based Hypothesis test of ``SegmentLRUCache``.

    A plain dict (``_model``) mirrors what the cache should contain. Each rule
    applies the same operation to the cache and to the model and asserts that
    they agree. The cache's eviction callback removes evicted keys from the
    model, so the model follows evictions without re-implementing LRU order.
    """

    _model: Dict[uuid.UUID, Segment]
    collection_keys = Bundle("collection_keys")

    # Size reported to the cache for every entry, so at most
    # ``capacity // _ITEM_SIZE`` entries fit at once.
    _ITEM_SIZE = 10

    def __init__(self, capacity: int):
        super().__init__()
        # Keys the cache has evicted, in eviction order.
        self.evicted_items = []
        self._cache = SegmentLRUCache(
            capacity=capacity,
            size_func=lambda _: self._ITEM_SIZE,
            callback=self._on_evict,
        )
        self._model = {}
        self._capacity = capacity

    def _on_evict(self, key: uuid.UUID, value: Segment) -> None:
        """Eviction callback: record ``key`` and drop it from the model."""
        self.evicted_items.append(key)
        self._model.pop(key)

    @rule(collection_id=collection_keys)
    def test_get(self, collection_id) -> None:
        """A key still present in the model must map to the same segment in the cache."""
        if collection_id not in self._model:
            # Already popped or evicted; nothing to compare.
            return
        expected = self._model.get(collection_id)
        assert self._cache.get(collection_id) == expected

    @rule(collection_id=consumes(collection_keys))
    def test_pop(self, collection_id) -> None:
        """Pop a key from both model and cache; the removed segments must match."""
        if collection_id not in self._model:
            return
        expected = self._model.pop(collection_id)
        assert self._cache.pop(collection_id) == expected

    @rule(target=collection_keys)
    def test_set(self) -> MultipleResults[uuid.UUID]:
        """Insert a fresh segment and check retrieval, capacity and model agreement.

        Returns:
            The new collection id, added to the ``collection_keys`` bundle.
        """
        segment = new_segment()
        collection_id = segment["collection"]
        self._model[collection_id] = segment
        self._cache.set(collection_id, segment)
        assert self._cache.get(collection_id) == segment
        # Capacity bounds the total reported size, not the number of entries,
        # so check the summed size of the resident entries.
        assert len(self._cache.cache) * self._ITEM_SIZE <= self._capacity
        # Every eviction goes through ``_on_evict``, which also removes the key
        # from the model, so both must hold exactly the same keys.
        assert set(self._cache.cache) == set(self._model)
        if self.evicted_items:
            last_evicted = self.evicted_items[-1]
            assert last_evicted not in self._cache.cache
        return multiple(collection_id)

    def teardown(self):
        """Clear the cache and the model after each state-machine run."""
        self._cache.reset()
        self._model.clear()
@given(capacity=st.integers(min_value=10, max_value=1000))
@settings(max_examples=20)
def test_caches(capacity: int) -> None:
    """Run ``LRUCacheStateMachine`` against a cache of the drawn ``capacity``."""

    def make_machine():
        # Hypothesis calls this factory to build a fresh machine per run.
        return LRUCacheStateMachine(capacity=capacity)

    run_state_machine_as_test(make_machine)  # type: ignore
def new_segment(collection_id: Optional[uuid.UUID] = None) -> Segment:
    """Build a test ``Segment`` (type ``"test"``, ``VECTOR`` scope).

    Args:
        collection_id: Collection the segment belongs to; a random UUID is
            generated when omitted.

    Returns:
        A new segment with a random id, no metadata and no file paths.
    """
    owner = uuid.uuid4() if collection_id is None else collection_id
    segment_id = uuid.uuid4()
    return Segment(
        id=segment_id,
        type="test",
        scope=SegmentScope.VECTOR,
        collection=owner,
        metadata=None,
        file_paths={},
    )
class CacheSetup:
    """Shared state for the multi-threaded cache stress test.

    Holds the cache under test, the workload size, and a ``metrics`` dict that
    worker threads update while holding ``lock``.
    """

    DEFAULT_ITERATIONS = 1000
    DEFAULT_NUM_THREADS = 50

    def __init__(
        self,
        cache: SegmentCache,
        iterations: Optional[int] = DEFAULT_ITERATIONS,
        num_threads: Optional[int] = DEFAULT_NUM_THREADS,
    ):
        """
        Args:
            cache: Cache instance shared by all worker threads.
            iterations: Operations per worker; ``None`` means the default.
            num_threads: Number of worker threads; ``None`` means the default.
        """
        self.cache: SegmentCache = cache
        # Both values are used as counts by the stress test, which cannot
        # accept None, so honour the Optional signature by falling back.
        self.iterations: int = (
            self.DEFAULT_ITERATIONS if iterations is None else iterations
        )
        self.num_threads: int = (
            self.DEFAULT_NUM_THREADS if num_threads is None else num_threads
        )
        self.metrics: Dict[str, Any] = {
            "errors": [],
            "time_to_first_error": None,
            "error_timings": [],
        }
        # Serialises updates to ``metrics`` from concurrent workers.
        self.lock = threading.Lock()
def _get_segment_disk_size(_: uuid.UUID) -> int:
    """Return a pseudo-random size between 1 and 9 inclusive; the key is ignored.

    Used as the cache's ``size_func`` so entries occupy varying amounts of
    capacity.
    """
    lowest, upper_exclusive = 1, 10
    return np.random.randint(lowest, upper_exclusive)
def callback_cache_evict(_: Segment) -> None:
    """Eviction hook for the stress test; deliberately does nothing."""
    return None
@given(
    capacity=st.integers(min_value=1, max_value=1000),
    num_threads=st.integers(min_value=1, max_value=40),
    iterations=st.integers(min_value=1, max_value=800),
)
@settings(max_examples=20)
def test_thread_safety(capacity: int, num_threads: int, iterations: int) -> None:
    """Hammer one shared ``SegmentLRUCache`` from many threads; no operation may raise.

    Each worker mixes gets, sets, stale gets and occasional resets. Any
    exception raised by a cache operation is recorded in ``metrics`` and fails
    the test.
    """
    cache_setup = CacheSetup(
        SegmentLRUCache(
            capacity=capacity,
            callback=lambda k, v: callback_cache_evict(v),
            size_func=lambda k: _get_segment_disk_size(k),
        ),
        iterations=iterations,
        num_threads=num_threads,
    )

    def worker() -> None:
        """Run ``cache_setup.iterations`` random cache operations.

        The first exception stops this worker and is recorded (under
        ``cache_setup.lock``) together with its elapsed time.
        """
        start_time = time.perf_counter()
        try:
            # Exactly ``iterations`` rounds per worker.
            for _ in range(cache_setup.iterations):
                # Snapshot of the keys currently held in the cache's dict.
                cache_keys = list(cache_setup.cache.cache.keys())
                if np.random.uniform(0, 1) < 0.5 and len(cache_keys) > 0:
                    cache_setup.cache.get(np.random.choice(cache_keys))
                else:
                    key = uuid.uuid4()
                    segment = new_segment(key)
                    cache_setup.cache.set(key, segment)
                # Random pause to vary how threads interleave.
                sleep(np.random.uniform(0, 0.01))
                # The snapshot may be stale by now; these keys may be gone.
                if np.random.uniform(0, 1) < 0.3 and len(cache_keys) > 0:
                    cache_setup.cache.get(np.random.choice(cache_keys))
                # Occasionally reset the cache while other threads use it.
                if np.random.uniform(0, 1) < 0.05:
                    cache_setup.cache.reset()
        except Exception as e:
            with cache_setup.lock:
                cache_setup.metrics["errors"].append(e)
                time_to_error = time.perf_counter() - start_time
                cache_setup.metrics["error_timings"].append(time_to_error)
                if cache_setup.metrics["time_to_first_error"] is None:
                    cache_setup.metrics["time_to_first_error"] = time_to_error

    with ThreadPoolExecutor(max_workers=cache_setup.num_threads) as executor:
        futures = [executor.submit(worker) for _ in range(cache_setup.num_threads)]
    # worker() records cache errors itself; result() re-raises anything that
    # escaped it instead of letting the executor silently drop it.
    for future in futures:
        future.result()

    print(cache_setup.metrics)
    assert len(cache_setup.metrics["errors"]) == 0, "Thread safety issues found"