import hashlib
import importlib
import json
import os
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional
def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
def default_override_dir():
return os.path.join(Path.home(), ".triton", "override")
def default_dump_dir():
return os.path.join(Path.home(), ".triton", "dump")
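# Note: the cache directory can be overridden at runtime via the
# TRITON_CACHE_DIR environment variable (see FileCacheManager below).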
class CacheManager(ABC):
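    """
    Abstract interface for Triton's compilation cache. Implementations map a
    compilation key to individual files (get_file/put) and to named groups of
    related files (get_group/put_group).
    """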
def __init__(self, key):
pass
@abstractmethod
def get_file(self, filename) -> Optional[str]:
pass
@abstractmethod
def put(self, data, filename, binary=True) -> str:
pass
@abstractmethod
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
pass
@abstractmethod
def put_group(self, filename: str, group: Dict[str, str]):
pass
class FileCacheManager(CacheManager):
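    """
    Cache manager backed by a directory on the local filesystem. Writes go to
    a per-process temporary file first and are renamed into place with
    os.replace, so readers never observe partial writes.
    """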
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)
def has_file(self, filename) -> bool:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
return os.path.exists(self._make_path(filename))
def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
return self._make_path(filename)
else:
return None
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)
# Invalid group data.
if child_paths is None:
return None
result = {}
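        # Only report children whose files still exist on disk; others may
        # have been evicted or removed.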
for c, p in child_paths.items():
if os.path.exists(p):
result[c] = p
return result
    # Record a set of previously pushed files as belonging to a single group
def put_group(self, filename: str, group: Dict[str, str]) -> str:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
grp_contents = json.dumps({"child_paths": group})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename, binary=False)
def put(self, data, filename, binary=True) -> str:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
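        # Infer binary-ness from the data's actual type; the keyword argument
        # is effectively only a hint and is overridden here.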
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
        # include the PID so that, if several temp files are left behind, we can tell which process created each one
pid = os.getpid()
# use tempfile to be robust against program interruptions
temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
return filepath
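# A minimal usage sketch (hypothetical key and filenames), assuming write
# access to the cache directory:
#
#   fm = FileCacheManager("example-key")
#   path = fm.put(b"\x00\x01", "kernel.cubin")      # atomic write
#   assert fm.get_file("kernel.cubin") == path
#   fm.put_group("kernel.cubin", {"kernel.cubin": path})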
class RemoteCacheBackend(ABC):
"""
A backend implementation for accessing a remote/distributed cache.
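
    Concrete implementations are selected via the TRITON_REMOTE_CACHE_BACKEND
    environment variable, formatted as "module.path:ClassName".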
"""
def __init__(self, key: str):
pass
@abstractmethod
def get(self, filenames: List[str]) -> Dict[str, bytes]:
pass
@abstractmethod
def put(self, filename: str, data: bytes):
pass
class RedisRemoteCacheBackend(RemoteCacheBackend):
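    """
    Remote cache backend backed by Redis. The connection and key layout are
    configured via the TRITON_REDIS_HOST, TRITON_REDIS_PORT and
    TRITON_REDIS_KEY_FORMAT environment variables.
    """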
def __init__(self, key):
import redis
self._key = key
self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
self._redis = redis.Redis(
host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
)
def _get_key(self, filename: str) -> str:
return self._key_fmt.format(key=self._key, filename=filename)
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
results = self._redis.mget([self._get_key(f) for f in filenames])
return {filename: result for filename, result in zip(filenames, results) if result is not None}
    def put(self, filename: str, data: bytes):
self._redis.set(self._get_key(filename), data)
class RemoteCacheManager(CacheManager):
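    """
    Cache manager that reads and writes through a remote backend while
    materializing fetched data on the local filesystem (via a backing
    FileCacheManager), so callers always receive ordinary file paths.
    """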
def __init__(self, key, override=False, dump=False):
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
module_path, clz_nme = remote_cache_manager.split(":")
module = importlib.import_module(module_path)
remote_cache_cls = getattr(module, clz_nme)
self._backend = remote_cache_cls(key)
self._override = override
self._dump = dump
# Use a `FileCacheManager` to materialize remote cache paths locally.
self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)
def _materialize(self, filename: str, data: bytes):
# We use a backing `FileCacheManager` to provide the materialized data.
return self._file_cache_manager.put(data, filename, binary=True)
def get_file(self, filename: str) -> Optional[str]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_file(filename)
# We always check the remote cache backend -- even if our internal file-
# based cache has the item -- to make sure LRU accounting works as
# expected.
results = self._backend.get([filename])
if len(results) == 0:
return None
(_, data), = results.items()
return self._materialize(filename, data)
def put(self, data, filename: str, binary=True) -> str:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put(data, filename, binary=binary)
if not isinstance(data, bytes):
data = str(data).encode("utf-8")
self._backend.put(filename, data)
return self._materialize(filename, data)
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_group(filename)
grp_filename = f"__grp__{filename}"
grp_filepath = self.get_file(grp_filename)
if grp_filepath is None:
return None
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)
result = None
# Found group data.
if child_paths is not None:
result = {}
for child_path, data in self._backend.get(child_paths).items():
result[child_path] = self._materialize(child_path, data)
return result
def put_group(self, filename: str, group: Dict[str, str]):
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put_group(filename, group)
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename)
__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"
def get_cache_manager(key) -> CacheManager:
user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
global __cache_cls
global __cache_cls_nme
if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
__cache_cls_nme = user_cache_manager
return __cache_cls(key)
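# A usage sketch (hypothetical module and class names): setting
#
#   TRITON_CACHE_MANAGER="my_pkg.cache:MyCacheManager"
#
# makes get_cache_manager(key) import my_pkg.cache and return
# MyCacheManager(key) instead of the default FileCacheManager.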
def get_override_manager(key) -> CacheManager:
return __cache_cls(key, override=True)
def get_dump_manager(key) -> CacheManager:
return __cache_cls(key, dump=True)
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
# Get unique key for the compiled code
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
for kw in kwargs:
key = f"{key}-{kwargs.get(kw)}"
key = hashlib.sha256(key.encode("utf-8")).hexdigest()
return key
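# For example (hypothetical values): a signature of {"x": "*fp32", "n": "i32"}
# is normalized to {"x": "ptr", "n": "i32"} before hashing, so all pointer
# argument types share a single cache key.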