from abc import ABCMeta, abstractmethod, abstractclassmethod
class DriverBase(metaclass=ABCMeta):
    """Abstract interface that every driver implementation must satisfy.

    Concrete subclasses must override both ``is_active`` and
    ``get_current_target``; until they do, instantiation raises ``TypeError``.
    """

    # ``abc.abstractclassmethod`` is deprecated (since Python 3.3); the
    # supported spelling stacks ``classmethod`` over ``abstractmethod``, with
    # ``abstractmethod`` applied as the innermost decorator.
    @classmethod
    @abstractmethod
    def is_active(cls):
        """Abstract class-level hook; subclasses must override it.

        NOTE(review): presumably reports whether this driver can be used in
        the current environment -- confirm against concrete implementations.
        """

    @abstractmethod
    def get_current_target(self):
        """Abstract hook; subclasses must override it.

        The type and meaning of the returned target are defined by subclasses.
        """

    def __init__(self) -> None:
        """Do nothing; subclasses may chain to it via ``super().__init__()``."""
class GPUDriver(DriverBase):
    """Driver base that binds device and stream helpers from ``torch.cuda``.

    ``is_active`` and ``get_current_target`` are still abstract here, so this
    class must be subclassed before it can be instantiated.
    """

    def __init__(self):
        """Bind device-query and stream helpers onto the instance.

        ``torch`` is imported inside the constructor, not at class-definition
        time, so only constructing an instance needs it.

        Raises:
            ImportError: if ``torch`` itself cannot be imported.
        """
        super().__init__()
        # TODO: support other frameworks than torch
        import torch

        self.get_device_capability = torch.cuda.get_device_capability
        try:
            # Prefer the private C binding when this torch build exposes it.
            from torch._C import _cuda_getCurrentRawStream
        except ImportError:
            # Fallback through the public API: return the ``cuda_stream``
            # attribute of the current stream for device ``idx``.
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        else:
            self.get_current_stream = _cuda_getCurrentRawStream
        self.get_current_device = torch.cuda.current_device
        self.set_current_device = torch.cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return ``args`` unchanged; ``tensormaps_info`` is ignored.

        Currently a pass-through placeholder.
        """
        return args