"""isort:skip_file"""
# Import order is significant here.
from . import math
from . import extra
from .standard import (
argmax,
argmin,
cdiv,
cumprod,
cumsum,
flip,
interleave,
max,
min,
ravel,
sigmoid,
softmax,
sort,
sum,
swizzle2d,
xor_sum,
zeros,
zeros_like,
)
from .core import (
PropagateNan,
TRITON_MAX_TENSOR_NUMEL,
_experimental_descriptor_load,
_experimental_descriptor_store,
advance,
arange,
associative_scan,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
broadcast,
broadcast_to,
cat,
cast,
clamp,
const,
const_pointer_type,
constexpr,
debug_barrier,
device_assert,
device_print,
dot,
dtype,
expand_dims,
float16,
float32,
float64,
float8e4b15,
float8e4nv,
float8e4b8,
float8e5,
float8e5b16,
full,
function_type,
histogram,
inline_asm_elementwise,
int1,
int16,
int32,
int64,
int8,
join,
load,
make_block_ptr,
max_constancy,
max_contiguous,
maximum,
minimum,
multiple_of,
num_programs,
permute,
pi32_t,
pointer_type,
program_id,
range,
reduce,
reshape,
split,
static_assert,
static_print,
static_range,
store,
tensor,
trans,
uint16,
uint32,
uint64,
uint8,
view,
void,
where,
)
from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
ceil)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint_to_uniform_float,
)
# Public API of this package, as exposed by a star-import.
# Every entry must name something actually bound in this module's namespace
# (or an importable submodule); a star-import raises AttributeError on any
# entry that is not.
__all__ = [
    "PropagateNan",
    "TRITON_MAX_TENSOR_NUMEL",
    "_experimental_descriptor_load",
    "_experimental_descriptor_store",
    "abs",
    "advance",
    "arange",
    "argmax",
    "argmin",
    "associative_scan",
    "atomic_add",
    "atomic_and",
    "atomic_cas",
    "atomic_max",
    "atomic_min",
    "atomic_or",
    "atomic_xchg",
    "atomic_xor",
    "bfloat16",
    "block_type",
    "broadcast",
    "broadcast_to",
    "cat",
    "cast",
    "cdiv",
    "ceil",
    "clamp",
    "const",
    "const_pointer_type",
    "constexpr",
    "cos",
    "cumprod",
    "cumsum",
    "debug_barrier",
    "device_assert",
    "device_print",
    "div_rn",
    "dot",
    "dtype",
    "erf",
    "exp",
    "exp2",
    "expand_dims",
    "extra",
    "fdiv",
    "flip",
    "float16",
    "float32",
    "float64",
    "float8e4b15",
    "float8e4nv",
    "float8e4b8",
    "float8e5",
    "float8e5b16",
    "floor",
    "fma",
    "full",
    "function_type",
    "histogram",
    "inline_asm_elementwise",
    "interleave",
    "int1",
    "int16",
    "int32",
    "int64",
    "int8",
    "join",
    "load",
    "log",
    "log2",
    "make_block_ptr",
    "math",
    "max",
    "max_constancy",
    "max_contiguous",
    "maximum",
    "min",
    "minimum",
    "multiple_of",
    "num_programs",
    "pair_uniform_to_normal",
    "permute",
    "philox",
    "philox_impl",
    "pi32_t",
    "pointer_type",
    "program_id",
    "rand",
    "rand4x",
    "randint",
    "randint4x",
    "randn",
    "randn4x",
    "range",
    "ravel",
    "reduce",
    "reshape",
    "rsqrt",
    "sigmoid",
    "sin",
    "softmax",
    "sort",
    "split",
    "sqrt",
    "sqrt_rn",
    "static_assert",
    "static_print",
    "static_range",
    "store",
    "sum",
    "swizzle2d",
    "tensor",
    "trans",
    "uint16",
    "uint32",
    "uint64",
    "uint8",
    "uint_to_uniform_float",
    "umulhi",
    "view",
    "void",
    "where",
    "xor_sum",
    "zeros",
    "zeros_like",
]
def str_to_ty(name):
    """Translate a short type string into the corresponding type object.

    A leading ``"*"`` marks a pointer: the remainder is resolved recursively
    and wrapped with ``pointer_type``. A leading ``"*k"`` does the same but
    wraps with ``const_pointer_type``. Anything else is looked up directly
    in the scalar table below.

    Raises:
        IndexError: if ``name`` is empty, or nothing follows ``"*"`` / ``"*k"``.
        KeyError: if the (pointer-stripped) name is not in the scalar table.
    """
    if name[0] != "*":
        # Scalar case: plain table lookup.
        scalar_types = {
            # floating point
            "fp8e4nv": float8e4nv,
            "fp8e4b8": float8e4b8,
            "fp8e5": float8e5,
            "fp8e5b16": float8e5b16,
            "fp8e4b15": float8e4b15,
            "fp16": float16,
            "bf16": bfloat16,
            "fp32": float32,
            "fp64": float64,
            # signed integers
            "i1": int1,
            "i8": int8,
            "i16": int16,
            "i32": int32,
            "i64": int64,
            # unsigned integers ("u1" shares int1 with "i1")
            "u1": int1,
            "u8": uint8,
            "u16": uint16,
            "u32": uint32,
            "u64": uint64,
            # "B" is a further alias of int1
            "B": int1,
        }
        return scalar_types[name]

    # Pointer case: drop the "*" and inspect what follows.
    pointee = name[1:]
    if pointee[0] == "k":
        # "*k<type>" -> const pointer to <type>.
        element_ty = str_to_ty(pointee[1:])
        return const_pointer_type(element_ty)

    element_ty = str_to_ty(pointee)
    return pointer_type(element_ty)