import functools
import triton
from triton._C.libproton import proton as libproton
from .hook import register_triton_hook, unregister_triton_hook
from .flags import set_profiling_off, set_profiling_on, is_command_line
from typing import Optional
# Session name used by start() when the caller does not pass one.
DEFAULT_PROFILE_NAME = "proton"
def _select_backend() -> str:
    """
    Pick the profiler backend that matches the active Triton target.

    Returns:
        str: "cupti" for CUDA targets, "roctracer" for HIP targets.

    Raises:
        ValueError: If the active target's backend has no matching profiler.
    """
    target_backend = triton.runtime.driver.active.get_current_target().backend
    # Triton target backend -> profiler backend understood by libproton.
    profiler_for_target = {
        "cuda": "cupti",
        "hip": "roctracer",
    }
    if target_backend not in profiler_for_target:
        raise ValueError("No backend is available for the current target.")
    return profiler_for_target[target_backend]
def start(
    name: Optional[str] = None,
    *,
    context: Optional[str] = "shadow",
    data: Optional[str] = "tree",
    backend: Optional[str] = None,
    hook: Optional[str] = None,
):
    """
    Start profiling with the given name and backend.

    Usage:
        ```python
        proton.start("my_profile")
        # do something
        proton.finalize()
        ```

    Args:
        name (str, optional): The name (with path) of the profiling session.
            If not provided, DEFAULT_PROFILE_NAME ("proton") is used.
        context (str, optional): The context to use for profiling.
            Available options are ["shadow", "python"].
            Defaults to "shadow".
        data (str, optional): The data structure to use for profiling.
            Available options are ["tree"].
            Defaults to "tree".
        backend (str, optional): The backend to use for profiling.
            Available options are ["cupti", "roctracer"].
            Defaults to None, which automatically selects the backend matching the
            current active runtime ("cupti" for CUDA targets, "roctracer" for HIP targets).
        hook (str, optional): The hook to use for profiling.
            Available options are [None, "triton"].
            Defaults to None.

    Returns:
        session (int or None): The session ID of the profiling session, or None when the
            call is ignored because the script is being run from the command line.

    Raises:
        ValueError: If backend is None and no profiler backend matches the active target.
    """
    if is_command_line():
        # Ignore the start() call if the script is run from the command line.
        return None
    if name is None:
        name = DEFAULT_PROFILE_NAME
    if backend is None:
        backend = _select_backend()
    set_profiling_on()
    if hook == "triton":
        register_triton_hook()
    return libproton.start(name, context, data, backend)
def activate(session: Optional[int] = 0) -> None:
    """
    Turn data recording back on for a profiling session.

    Args:
        session (int): ID of the session to activate.
            Defaults to 0, the first session that was started.

    Raises:
        ValueError: When running from the command line and session is not 0.

    Returns:
        None
    """
    if is_command_line():
        # From the command line only session 0 may be targeted.
        if session != 0:
            raise ValueError("Only one session can be activated when running from the command line.")
    libproton.activate(session)
def deactivate(session: Optional[int] = 0) -> None:
    """
    Pause data recording for a profiling session.

    Data already collected stays in memory; nothing new is recorded until the
    session is activated again.

    Args:
        session (int): ID of the session to deactivate.
            Defaults to 0, the first session that was started.

    Raises:
        ValueError: When running from the command line and session is not 0.

    Returns:
        None
    """
    if is_command_line():
        # From the command line only session 0 may be targeted.
        if session != 0:
            raise ValueError("Only one session can be deactivated when running from the command line.")
    libproton.deactivate(session)
def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None:
    """
    Finalize one profiling session, or all of them.

    Flushes the profiling data and writes it to the file specified by the session name.

    Args:
        session (int, optional): ID of the session to finalize. When None (the default),
            every session is finalized, profiling is switched off, and the triton hook
            is unregistered.
        output_format (str, optional): The output format for the profiling results.
            Available options are ["hatchet"].

    Raises:
        ValueError: When running from the command line and session is neither None nor 0.

    Returns:
        None
    """
    if session is not None:
        # Finalize a single session only; global profiling state is left untouched.
        if is_command_line() and session != 0:
            raise ValueError("Only one session can be finalized when running from the command line.")
        libproton.finalize(session, output_format)
        return
    set_profiling_off()
    libproton.finalize_all(output_format)
    unregister_triton_hook()
def _profiling(
    func,
    name: Optional[str] = None,
    context: Optional[str] = "shadow",
    data: Optional[str] = "tree",
    backend: Optional[str] = None,
    hook: Optional[str] = None,
):
    """
    Wrap func so that each call runs inside a profiling session. Internal use only.

    A session is started before func runs and deactivated after func returns or raises.

    Args:
        func (callable): The function to wrap.
        See start() for the remaining arguments.

    Returns:
        wrapper (function): The wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        session = start(name, context=context, data=data, backend=backend, hook=hook)
        try:
            return func(*args, **kwargs)
        finally:
            # start() returns None when the script runs from the command line. That session
            # is not owned by this wrapper, and deactivate(None) would raise ValueError there
            # (None != 0), so only deactivate a session this call actually started.
            if session is not None:
                deactivate(session)

    return wrapper
def profile(
    func=None,
    *,
    name: Optional[str] = None,
    context: Optional[str] = "shadow",
    data: Optional[str] = "tree",
    backend: Optional[str] = None,
    hook: Optional[str] = None,
):
    """
    Decorator that profiles every call of the decorated function.

    Can be used bare or with keyword arguments:
        ```python
        @proton.profile
        def foo():
            pass

        @proton.profile(name="my_profile")
        def bar():
            pass
        ```

    Args:
        func (callable, optional): The function to decorate when used without parentheses.
        See start() for the remaining arguments.

    Returns:
        The wrapped function when func is given, otherwise a decorator that wraps its argument.
    """
    options = dict(name=name, context=context, data=data, backend=backend, hook=hook)

    def decorator(f):
        return _profiling(f, **options)

    # Bare "@profile" receives the function directly; "@profile(...)" receives only keywords.
    return decorator if func is None else decorator(func)