import threading
from functools import wraps
from typing import Optional, Union
from .flags import get_profiling_on
from triton._C.libproton import proton as libproton
_local = threading.local()
MetricValueType = Union[float, int]
PropertyValueType = Union[float, int, str]
class scope:
"""
A context manager and decorator for entering and exiting a scope.
Usage:
context manager:
```python
with proton.scope("test0", {metric_name: metric_value}):
foo[1,](x, y)
```
decoarator:
```python
@proton.scope("test0", {metric_name: metric_value})
def foo(x, y):
...
```
Args:
name (str): The name of the scope.
metrics (dict[str, float], optional): The metrics of the scope. Default is None.
"""
def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None,
properties: Optional[dict[str, PropertyValueType]] = None) -> None:
self._name = name
self._metrics = metrics
self._properties = properties
def __enter__(self):
if not get_profiling_on():
return self
self._id = libproton.record_scope()
libproton.enter_scope(self._id, self._name)
if self._metrics:
libproton.add_metrics(self._id, self._metrics)
if self._properties:
libproton.set_properties(self._id, self._properties)
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
if not get_profiling_on():
return
libproton.exit_scope(self._id, self._name)
def __call__(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
if get_profiling_on():
id = libproton.record_scope()
libproton.enter_scope(id, self._name)
if self._metrics:
libproton.add_metrics(id, self._metrics)
if self._properties:
libproton.set_properties(id, self._properties)
ret = func(*args, **kwargs)
if get_profiling_on():
libproton.exit_scope(id, self._name)
return ret
return wrapper
def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[str, MetricValueType]] = None,
properties: Optional[dict[str, PropertyValueType]] = None) -> int:
if not get_profiling_on():
return -1
id = libproton.record_scope()
if not hasattr(_local, "scopes"):
_local.scopes = []
_local.scopes.append((id, name))
if triton_op:
libproton.enter_op(id, name)
else:
libproton.enter_scope(id, name)
if metrics:
libproton.add_metrics(id, metrics)
if properties:
libproton.set_properties(id, properties)
return id
def exit_scope(triton_op: bool = False) -> int:
if not get_profiling_on():
return -1
id, name = _local.scopes.pop()
if triton_op:
libproton.exit_op(id, name)
else:
libproton.exit_scope(id, name)
return id