diff --git a/bitblas/base/__init__.py b/bitblas/base/__init__.py index 122c44cbd..c6235ea42 100644 --- a/bitblas/base/__init__.py +++ b/bitblas/base/__init__.py @@ -1,18 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Base infra""" from .analysis import ( - BlockInfo, - IterInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, - is_broadcast_epilogue, - normalize_prim_func, -) -from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial -from .schedule_rule import ScheduleRule -from .transform import ApplyDefaultSchedule, ApplyFastTuning -from .utils import fast_tune, fast_tune_with_dynamic_range + BlockInfo, # noqa: F401 + IterInfo, # noqa: F401 + collect_block_iter_vars_used_in_access_region, # noqa: F401 + collect_vars_used_in_prim_expr, # noqa: F401 + detect_dominant_read, # noqa: F401 + is_broadcast_epilogue, # noqa: F401 + normalize_prim_func, # noqa: F401 +) # noqa: F401 +from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401 +from .schedule_rule import ScheduleRule # noqa: F401 +from .transform import ApplyDefaultSchedule, ApplyFastTuning # noqa: F401 +from .utils import fast_tune, fast_tune_with_dynamic_range # noqa: F401 from .roller import * +from .arch import CUDA, CDNA # noqa: F401 diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py index 9cb036792..27581d2f8 100644 --- a/bitblas/base/arch/__init__.py +++ b/bitblas/base/arch/__init__.py @@ -3,6 +3,7 @@ from .arch_base import TileDevice from .cuda import * from .cpu import * +from .cdna import * def get_arch(target: tvm.target.Target) -> TileDevice: @@ -10,5 +11,7 @@ def get_arch(target: tvm.target.Target) -> TileDevice: return CUDA(target) elif target.kind.name == "llvm": return CPU(target) + elif target.kind.name == "hip": + return CDNA(target) else: raise ValueError(f"Unsupported target: {target.kind.name}") diff --git a/bitblas/base/arch/cdna.py b/bitblas/base/arch/cdna.py new file mode 100644 index 000000000..f6805dc98 --- /dev/null +++ b/bitblas/base/arch/cdna.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm +from tvm.target import Target +from .arch_base import TileDevice +from typing import List, Union + + +class CDNA(TileDevice): + + def __init__(self, target: Union[Target, str]): + if isinstance(target, str): + target = tvm.target.Target(target) + self.target = target + device = tvm.runtime.rocm(0) + if not device.exist: + raise RuntimeError("Cannot find HIP device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CDNA" + self.smem_cap = device.max_shared_memory_per_block + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + self.compute_capability = device.compute_version.replace(".", "") + self.reg_cap: int = 32768 + self.max_smem_usage: int = 2 * self.smem_cap + self.sm_partition: int = 4 + self.l2_cache_size_bytes: int = target.l2_cache_size_bytes + self.transaction_size: List[int] = [32, 128] # in bytes + + self.bandwidth: List[int] = [1300, 14000] diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py index 21722b595..36a1fb7a0 100644 --- a/bitblas/base/roller/hint.py +++ b/bitblas/base/roller/hint.py @@ -161,7 +161,7 @@ def __init__(self) -> None: # Special axes tiling info self.block = [] self.thread = [] - # Special axes for MMA + # Special axes for MFMA self.warp = [] # Reduce axes tiling info self.rstep = [] diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index e6e427b5a..f355da044 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -13,7 +13,7 @@ from tvm.relax.expr import Function import bitblas from .analysis import get_root_block, get_reduction_blocks, find_var_from_func -from bitblas.base.arch import TileDevice, CUDA +from bitblas.base.arch import TileDevice, CUDA, CDNA from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.base.roller.hint import Hint from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags @@ -357,8 +357,8 @@ def fast_tune( if not isinstance(func, tir.PrimFunc): raise ValueError("Only support func is PrimFunc") # pragma: no cover - if target.kind.name != "cuda": - logger.error("Only support CUDA target") + if target.kind.name not in ["cuda", "hip"]: + logger.error("Only support CUDA and hip target") return None, None specilized_func = func @@ -385,7 +385,12 @@ def fast_tune( var: shape.astype(var.dtype) }).with_attr("is_specialized") - arch = CUDA(target) + if target.kind.name == "cuda": + arch = CUDA(target) + elif target.kind.name == "hip": + arch = CDNA(target) + else: + raise ValueError(f"Unsupported target: {target.kind.name}") policy = DefaultPolicy(func=func, arch=arch) try: diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 1a9ababd2..642198060 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -30,24 +30,41 @@ def load_lib(self): def compile_lib(self, timeout: float = None, with_tl: bool = False): arch = self.arch - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) - compute_version = arch.compute_capability - libpath = src.name.replace(".cu", ".so") - - command = [ - "nvcc", - "-std=c++17", - "-Xcudafe", - "--diag_suppress=177", - "--compiler-options", - "'-fPIC'", - "-lineinfo", - "--shared", - src.name, - "-lcuda", - "-gencode", - f"arch=compute_{compute_version},code=sm_{compute_version}", - ] + platform = arch.platform + if platform == "CUDA": + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + compute_version = arch.compute_capability + libpath = src.name.replace(".cu", ".so") + + command = [ + "nvcc", + "-std=c++17", + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "'-fPIC'", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + "-gencode", + f"arch=compute_{compute_version},code=sm_{compute_version}", + ] + + elif platform == "CDNA": + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) + libpath = src.name.replace(".cpp", ".so") + + command = [ + "hipcc", + "-std=c++17", + "-fPIC", + "--shared", + src.name, + ] + + else: + raise ValueError(f"Unsupported platform: {platform}") if with_tl: install_tvm_path = os.path.join( diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index b57981515..018454105 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -85,6 +85,9 @@ def get_cuda_init_func(self): init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs + def get_stream_type(self, function_args): + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + def update_lib_code(self, code: str): # Update the library code with the given code string self.lib_code = code @@ -113,7 +116,7 @@ def update_lib_code(self, code: str): for dyn_sym in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": "int"}) - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + self.get_stream_type(function_args) # Format the function arguments for declaration def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) @@ -249,7 +252,7 @@ def legalize_c(p): last_range = 0 num_items = len(function_informations) _call_str = """""" - for function_name, info in function_informations.items(): + for last_range, (function_name, info) in enumerate(function_informations.items()): # Prepare block and grid configurations for kernel launches block_info, grid_info = info["block_info"], info["grid_info"] block_str = "dim3({}, {}, {})".format( @@ -292,7 +295,6 @@ def legalize_c(p): if last_range == num_items - 1: call_str += (" else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( function_name, grid_str, block_str, smem_str, call_args)) - last_range += 1 _call_str += call_str # Wrap the kernel dispatch logic in an external C function @@ -368,6 +370,27 @@ def compare_map_objects(map_obj): return lib_code +class TIRHIPSourceWrapper(TIRCUDASourceWrapper): + + def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice): + super().__init__(scheduled_ir_module, source, arch) + + def get_hip_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name, + self.dynamic_smem_buf)) + # Format the initialization function using the call_str + init_funcs = PREDEF_INIT_FUNC.format(call_str) + return init_funcs + + def get_stream_type(self, function_args): + function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},) + + class TIRWrapper(BaseWrapper): def __init__(self, arch: TileDevice): @@ -382,6 +405,12 @@ def assign_optimized_module(self, scheduled_ir_module: IRModule): # Get Scheduled Rt Module and return source to be compiled def wrap(self, c_source: str, is_dynamic: bool = False): assert self.scheduled_ir_module is not None, "Please assign optimized module first." - wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic + if self.arch.platform == "CUDA": + wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic + elif self.arch.platform == "CDNA": + wrapper_class = TIRHIPSourceWrapper + else: + raise ValueError(f"Unsupported platform: {self.arch.platform}") + wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch) return wrapper.lib_code diff --git a/bitblas/gpu/__init__.py b/bitblas/gpu/__init__.py index df0635b3c..72192f60e 100644 --- a/bitblas/gpu/__init__.py +++ b/bitblas/gpu/__init__.py @@ -13,6 +13,7 @@ Matmul, # noqa: F401 MatmulTensorizationMMA, # noqa: F401 MatmulTensorizationWMMA, # noqa: F401 + MatmulTensorizationMFMA, # noqa: F401 ) from .matmul_mma_dequantize import ( MatmulTensorizationMMAWithDequantizeInfo, # noqa: F401 diff --git a/bitblas/gpu/intrin/__init__.py b/bitblas/gpu/intrin/__init__.py index d9d9ba942..b6f1992f3 100644 --- a/bitblas/gpu/intrin/__init__.py +++ b/bitblas/gpu/intrin/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .lop3 import get_lop3_intrin_group # noqa: F401 +from .hip import get_mfma_intrin_group # noqa: F401 diff --git a/bitblas/gpu/intrin/hip.py b/bitblas/gpu/intrin/hip.py new file mode 100644 index 000000000..9883eaed1 --- /dev/null +++ b/bitblas/gpu/intrin/hip.py @@ -0,0 +1,1029 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.runtime import convert +from tvm.tir.expr import Cast, IntImm +from tvm.tir.function import TensorIntrin +from tvm.script import tir as T +from tvm._ffi import register_func +from typing import Dict, Tuple, Optional, List +from typing_extensions import Literal +from tvm.tir.function import PrimFunc + +lift = convert + +WARP_SIZE = 64 +M_DIM = 16 +N_DIM = 16 + + +def shared_16x4_to_local_64x1_layout_A(i, j): + thread_id = (j * 16 + i) + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): + i = thread_id % 16 + j = thread_id // 16 + return i, j + + +def shared_4x16_to_local_64x1_layout_B(i, j): + thread_id = (i * 16 + j) + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): + i = thread_id // 16 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_C(i, j): + thread_id = j + (i // 4) * 16 + local = (i % 4) + return thread_id, local + + +@register_func("tir.index_map.shared_16x16_to_ldmatrix_64x4_layout") +def shared_16x16_to_ldmatrix_64x4_layout(ind): + i, j = ind[0], ind[1] + thread_id, local_id = shared_16x16_to_local_64x4_layout_C(i, j) + return convert([thread_id, local_id]) + + +def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): + i = thread_id % 16 + j = (thread_id // 16) * 4 + local_id + return i, j + + +def shared_16x16_to_local_64x4_layout_A(i, j): + thread_id = i + 16 * (j // 4) + local = (j % 4) + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_B(i, j): + thread_id = j + (i // 4) * 16 + local = (i % 4) + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def get_mma_fill_intrin(dtype, local_size): + zero = IntImm("int32", 0).astype(dtype) + + # Assume M = N = 16 + index_map = shared_16x16_to_local_64x4_layout_C + + @T.prim_func + def mma_fill_desc(a: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + i, j = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = T.meta_var(index_map(i, j)) + T.reads() + T.writes(C_warp[thread_id, local_id]) + C_warp[thread_id, local_id] = zero + + @T.prim_func + def mma_fill_impl(a: T.handle) -> None: + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1) + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for tx in T.thread_binding(WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(0, local_size): + C_warp[tx, local_id] = zero + + return mma_fill_desc, mma_fill_impl + + +def get_mfma_load_intrin( + k_dim=4, + dtype="float32", + scope="shared", + is_b=False, + transposed=False, +): + + local_size = (M_DIM * k_dim) // WARP_SIZE if not is_b else (N_DIM * k_dim) // WARP_SIZE + memory_shape = (M_DIM, k_dim) + if is_b: + memory_shape = (N_DIM, k_dim) if transposed else (k_dim, N_DIM) + + row_dim, col_dim = memory_shape + + if k_dim == 4: + index_map = shared_16x4_to_local_64x1_layout_A + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A + if is_b: + index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B + elif k_dim == 16: + index_map = shared_16x16_to_local_64x4_layout_A + reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A + + if is_b: + index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B + reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B + else: + raise ValueError("k_dim must be 4 or 16 currently") + + @T.prim_func + def mfma_load_desc(reg_handle: T.handle, memory_handle: T.handle) -> None: + memory = T.match_buffer( + memory_handle, + memory_shape, + dtype, + offset_factor=1, + scope=scope, + ) + reg = T.match_buffer( + reg_handle, (WARP_SIZE, local_size), dtype, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads(memory[0:row_dim, 0:col_dim]) + T.writes(reg[0:WARP_SIZE, 0:local_size]) + + for ax0, ax1 in T.grid(row_dim, col_dim): + with T.block("memory_reg"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(memory[v0, v1]) + + thread_id, local_id = T.meta_var(index_map(v0, v1)) + T.writes(reg[thread_id, local_id]) + reg[thread_id, local_id] = memory[v0, v1] + + @T.prim_func + def mfma_load_impl(reg_handle: T.handle, memory_handle: T.handle) -> None: + s0 = T.int32() + s1 = T.int32() + + memory = T.match_buffer( + memory_handle, + memory_shape, + dtype, + align=64, + offset_factor=1, + scope=scope, + strides=[s0, s1], + ) + reg = T.match_buffer( + reg_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads(memory[0:row_dim, 0:col_dim]) + T.writes(reg[0:WARP_SIZE, 0:local_size]) + for tx in T.thread_binding(WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(local_size): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + reg[tx, local_id] = memory[row, col] + + return mfma_load_desc, mfma_load_impl + + +def get_mfma_intrin(k_dim, in_dtype="float32", out_dtype="float32", b_transposed=False): + + local_size = (M_DIM * k_dim) // WARP_SIZE + local_size_out = (M_DIM * N_DIM) // WARP_SIZE + compute_in_dtype = in_dtype if local_size == 1 else f"{in_dtype}x{local_size}" + compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" + + if k_dim == 4: + index_map_A = shared_16x4_to_local_64x1_layout_A + index_map_B = shared_4x16_to_local_64x1_layout_B + index_map_C = shared_16x16_to_local_64x4_layout_C + elif k_dim == 16: + index_map_A = shared_16x16_to_local_64x4_layout_A + index_map_B = shared_16x16_to_local_64x4_layout_A if b_transposed else shared_16x16_to_local_64x4_layout_B + index_map_C = shared_16x16_to_local_64x4_layout_C + else: + raise ValueError("k_dim must be 4 or 16 currently") + + out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] + + in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[in_dtype] + + mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" + + def maybe_cast(v): + if out_dtype != in_dtype: + return Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + @T.prim_func + def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=16, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=16, scope="warp") + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=16, scope="warp") + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + + for i, j, k in T.grid(M_DIM, N_DIM, k_dim): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) + + thread_id_C, local_id_C = T.meta_var(index_map_C(i, j)) + thread_id_A, local_id_A = T.meta_var(index_map_A(i, k)) + thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind)) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += maybe_cast( + A[thread_id_A, local_id_A]) * maybe_cast(B[thread_id_B, local_id_B]) + + @T.prim_func + def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=16, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=16, scope="warp") + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=16, scope="warp") + + with T.block("root"): + T.reads( + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + C[0:WARP_SIZE, 0:local_size_out], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + T.evaluate( + T.tvm_mfma( + mfma_suffix, + "row", + "row", + compute_in_dtype, + compute_in_dtype, + compute_out_dtype, + A.data, + A.elem_offset // (WARP_SIZE * local_size_out), + B.data, + B.elem_offset // (WARP_SIZE * local_size_out), + C.data, + C.elem_offset // (WARP_SIZE * local_size_out), + dtype=compute_out_dtype, + )) + + @T.prim_func + def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads( + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + C[0:WARP_SIZE, 0:local_size_out], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.tvm_mfma( + mfma_suffix, + "row", + "row", + compute_in_dtype, + compute_in_dtype, + compute_out_dtype, + T.call_intrin("int32", "tir.reinterpret", A.data), + A.elem_offset, + T.call_intrin("int32", "tir.reinterpret", B.data), + B.elem_offset, + C.data, + C.elem_offset // (WARP_SIZE * local_size_out), + dtype=compute_out_dtype, + )) + + return (mfma_sync_desc, + mfma_sync_impl_integer) if in_dtype == "int8" else (mfma_sync_desc, + mfma_sync_impl_float) + + +def get_mfma_store_intrin(local_size=4, dtype="float32", scope="global"): + + index_map = shared_16x16_to_local_64x4_layout_C + + @T.prim_func + def mfma_store_desc(a: T.handle, c: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = T.meta_var(index_map(v0, v1)) + T.reads(C_warp[thread_id, local_id]) + T.writes(C[v0, v1]) + C[v0, v1] = C_warp[thread_id, local_id] + + @T.prim_func + def mfma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=16) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + T.evaluate( + T.tvm_mfma_store( + M_DIM, + N_DIM, + C.access_ptr("w"), + C_warp.data, + C_warp.elem_offset // (WARP_SIZE), + s0, + dtype=dtype, + )) + + return mfma_store_desc, mfma_store_impl + + +HIP_MFMA_fill_16x16_f32_INTRIN = "HIP_mfma_fill_16x16_f32" +TensorIntrin.register(HIP_MFMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 4)) + +HIP_MFMA_fill_16x16_i32_INTRIN = "HIP_mfma_fill_16x16_i32" +TensorIntrin.register(HIP_MFMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int", 4)) + +HIP_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN = "hip_mfma_load_16x16_a_shared_s8" +TensorIntrin.register(HIP_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, + *get_mfma_load_intrin(16, "int8", "shared")) +HIP_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN = "hip_mfma_load_b_16x16_shared_s8" +TensorIntrin.register(HIP_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, + *get_mfma_load_intrin(16, "int8", "shared", is_b=True)) + +HIP_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN = "hip_mfma_load_16x16_a_shared_f16" +TensorIntrin.register(HIP_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, + *get_mfma_load_intrin(16, "float16", "shared")) +HIP_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN = "hip_mfma_load_b_16x16_shared_f16" +TensorIntrin.register(HIP_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN, + *get_mfma_load_intrin(16, "float16", "shared", is_b=True)) +HIP_MFMA_LOAD_16x16_B_TRANS_SHARED_f16_INTRIN = "hip_mfma_load_b_trans_16x16_shared_f16" +TensorIntrin.register(HIP_MFMA_LOAD_16x16_B_TRANS_SHARED_f16_INTRIN, + *get_mfma_load_intrin(16, "float16", "shared", is_b=True, transposed=True)) + +HIP_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN = "hip_mfma_load_16x4_a_shared_f32" +TensorIntrin.register(HIP_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, + *get_mfma_load_intrin(4, "float32", "shared")) +HIP_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN = "hip_mfma_load_b_16x4_shared_f32" +TensorIntrin.register(HIP_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN, + *get_mfma_load_intrin(4, "float32", "shared", is_b=True)) +HIP_MFMA_LOAD_16x4_B_TRANS_SHARED_f32_INTRIN = "hip_mfma_load_b_trans_16x4_shared_f32" +TensorIntrin.register(HIP_MFMA_LOAD_16x4_B_TRANS_SHARED_f32_INTRIN, + *get_mfma_load_intrin(4, "float32", "shared", is_b=True, transposed=True)) + +HIP_MFMA_f32f32f32_INTRIN = "hip_mfma_f32f32f32" +TensorIntrin.register(HIP_MFMA_f32f32f32_INTRIN, *get_mfma_intrin(4, "float32", "float32")) + +HIP_MFMA_f16f16f32_INTRIN = "hip_mfma_f16f16f32" +TensorIntrin.register(HIP_MFMA_f16f16f32_INTRIN, *get_mfma_intrin(16, "float16", "float32")) + +HIP_MFMA_f16f16f32_TRANS_INTRIN = "hip_mfma_f16f16f32_trans" +TensorIntrin.register(HIP_MFMA_f16f16f32_TRANS_INTRIN, + *get_mfma_intrin(16, "float16", "float32", b_transposed=True)) + +HIP_MFMA_s8s8s32_INTRIN = "hip_mfma_s8s8s32" +TensorIntrin.register(HIP_MFMA_s8s8s32_INTRIN, *get_mfma_intrin(16, "int8", "int32")) + +HIP_MFMA_STORE_16x16_s32_INTRIN = "hip_mfma_store_16x16_s32" +TensorIntrin.register(HIP_MFMA_STORE_16x16_s32_INTRIN, *get_mfma_store_intrin(4, "int32", "global")) + +HIP_MFMA_STORE_16x16_f32_INTRIN = "hip_mfma_store_16x16_f32" +TensorIntrin.register(HIP_MFMA_STORE_16x16_f32_INTRIN, + *get_mfma_store_intrin(4, "float32", "global")) + + +def get_mfma_intrin_group( + load_scope: Literal["shared", "shared.dyn"] = "shared", + store_scope: Literal["global", "shared", "shared.dyn"] = "global", + a_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16", + b_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16", + out_dtype: Literal["float16", "float32", "int32"] = "float16", + trans_a: bool = False, + trans_b: bool = False, + not_use_mfma_store_intrinic: bool = True, + store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = None, +) -> Dict[str, str]: + """Get a group of intrinsics for mma tensor core with the given configurations + + Parameters + ---------- + load_scope : Literal["shared", "shared.dyn"] + The memory scope of the input buffer. + + store_scope : Literal["global", "shared", "shared.dyn"] + The memory scope of the result buffer. + + a_dtype : str + The dtype of the input matrix A. + + b_dtype : str + The dtype of the input matrix B. + + out_dtype : str + The output data dtype. + + trans_b : bool + Whether the input matrix B is transposed. + + not_use_mma_store_intrinic : bool + Whether to not use the mma_store intrinsic. If True, use BufferStore stmts to store the + result of mma. Otherwise, use mfma_store intrinsic. + + This is because if we use mfma_store intrinsic, during swizzling shared memory visits, our + rearrangement scheme will involve areas accessed by different mma_store calls. This makes + swizzling quite complex. But BufferStore will not face this problem. + + store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]] + The dtype that we use to store from register to shared memory. By default it is out_dtype. + + Returns + ------- + ret : Dict[str, str] + A group of tensor intrinsics. + """ + assert load_scope in ["shared", "shared.dyn"] + assert store_scope in ["global", "shared", "shared.dyn"] + assert a_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"] + assert b_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"] + assert out_dtype in ["float16", "float32", "int32"] + + shape = "16x16" + + dtype_mapping = { + "float16": "f16", + "bfloat16": "bf16", + "float32": "f32", + "int8": "i8", + "e4m3_float8": "e4m3", + "e5m2_float8": "e5m2", + "int32": "i32", + } + a_dtype = dtype_mapping[a_dtype] + b_dtype = dtype_mapping[b_dtype] + out_dtype = dtype_mapping[out_dtype] + + # e.g. HIP_mfma_fill_16x16_f32 + init_intrin = f"HIP_mfma_fill_{shape}_{out_dtype}" + + # TODO change format + # e.g. hip_mfma_load_16x4_a_shared_f32, hip_mfma_load_16x16_a_shared_s8 + # trans_a = "_trans" if trans_a else "" + # trans_b = "_trans" if trans_b else "" + + load_a_intrin = f"hip_mfma_load_{shape}_a_shared_{a_dtype}" + # hip_mfma_load_b_trans_16x16_shared_f16 + load_b_intrin = f"hip_mfma_load_b_{shape}_shared_{b_dtype}" if trans_b is False else f"hip_mfma_load_b_trans_{shape}_shared_{b_dtype}" + + # e.g. hip_mfma_f32f32f32, hip_mfma_f16f16f32_trans + compute_intrin = (f"hip_mfma_{a_dtype}{b_dtype}{out_dtype}") if trans_b is False else ( + f"hip_mfma_{a_dtype}{b_dtype}{out_dtype}_trans") + + # e.g. hip_mfma_store_global_16x16_s32 + store_scope = store_scope.replace(".", "_") + store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if store_to_smem_dtype else out_dtype + store_intrin = f"hip_mfma_store_{shape}_{store_to_smem_dtype}" + + index_map_c = shared_16x16_to_local_64x4_layout_C + if a_dtype in ["f16", "bf16"]: + index_map_a = shared_16x16_to_local_64x4_layout_A + index_map_b = shared_16x16_to_local_64x4_layout_B if trans_b is False else shared_16x16_to_local_64x4_layout_A + elif a_dtype in ["i8", "e4m3", "e5m2"]: + index_map_a = shared_16x4_to_local_64x1_layout_A + index_map_b = shared_4x16_to_local_64x1_layout_B + else: + raise ValueError(f"Unsupported in_dtype: {a_dtype}") + + # micro kernel size, the order is [m, n, k] + micro_kernel: List[int] + if a_dtype in ["f16", "bf16"]: + micro_kernel = [16, 16, 16] + elif a_dtype in ["i8", "e4m3", "e5m2"]: + micro_kernel = [16, 16, 32] + else: + raise ValueError(f"Unsupported in_dtype: {a_dtype}") + + return { + "init": init_intrin, + "load_a": load_a_intrin, + "load_b": load_b_intrin, + "compute": compute_intrin, + "store": store_intrin, + "index_map": [index_map_a, index_map_b, index_map_c], + "micro_kernel": micro_kernel, + } + + +######## ROCrocwmma intrinsics ######## + + +def get_rocwmma_fragment_index(buffer, stride, m_dim, n_dim): + """Compute rocwmma fragment index using elem_offset of the buffer""" + frag_index_m = buffer.elem_offset // stride // m_dim + frag_index_n = buffer.elem_offset % stride // n_dim + + num_fragments_per_row = stride // n_dim + return frag_index_m * num_fragments_per_row + frag_index_n + + +def get_rocwmma_load_intrin( + m_dim: int, + n_dim: int, + k_dim: int, + dtype: str, + shared_scope: str, + is_b: bool, + is_col_major: bool, +) -> Tuple[PrimFunc, PrimFunc]: + """Generator of rocwmma_load intrins""" + rocwmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a") + layout = "col_major" if is_col_major else "row_major" + + @T.prim_func + def rocwmma_load_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope) + C = T.match_buffer( + c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=rocwmma_fragment_scope) + with T.block("root"): + T.reads(A[0:m_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + for i, j in T.grid(m_dim, n_dim): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + @T.prim_func + def rocwmma_load_impl(a: T.handle, c: T.handle) -> None: + s1 = T.int32() + s0 = T.int32() + d1 = T.int32() + d0 = T.int32() + A = T.match_buffer( + a, + (m_dim, n_dim), + dtype, + align=64, + offset_factor=16, + scope=shared_scope, + strides=[s1, s0], + ) + C = T.match_buffer( + c, + (m_dim, n_dim), + dtype, + align=64, + offset_factor=16, + scope=rocwmma_fragment_scope, + strides=[d1, d0], + ) + with T.block("root"): + T.reads(A[0:m_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + T.evaluate( + T.tvm_load_matrix_sync( + C.data, + m_dim, + n_dim, + k_dim, + get_rocwmma_fragment_index(C, d1, m_dim, n_dim), + A.access_ptr("r"), + s1, + layout, + dtype="handle", + )) + + return rocwmma_load_desc, rocwmma_load_impl + + +def get_rocwmma_fill_intrin(m_dim: int, n_dim: int, k_dim: int, + dtype: str) -> Tuple[PrimFunc, PrimFunc]: + """Generator of rocwmma_fill intrins""" + zero = IntImm("int32", 0).astype(dtype) + + @T.prim_func + def rocwmma_fill_desc(c: T.handle) -> None: + C = T.match_buffer( + c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + T.reads() + T.writes(C[0:m_dim, 0:n_dim]) + for i, j in T.grid(m_dim, n_dim): + with T.block("init"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = zero + + @T.prim_func + def rocwmma_fill_impl(c: T.handle) -> None: + d1 = T.int32() + d0 = T.int32() + C = T.match_buffer( + c, + (m_dim, n_dim), + dtype, + align=64, + offset_factor=16, + scope="wmma.accumulator", + strides=[d1, d0], + ) + with T.block("root"): + T.reads() + T.writes(C[0:m_dim, 0:n_dim]) + T.evaluate( + T.tvm_fill_fragment( + C.data, + m_dim, + n_dim, + k_dim, + get_rocwmma_fragment_index(C, d1, m_dim, n_dim), + zero, + dtype="handle", + )) + + return rocwmma_fill_desc, rocwmma_fill_impl + + +def get_rocwmma_store_intrin(m_dim: int, n_dim: int, k_dim: int, dtype: str, + scope: str) -> Tuple[PrimFunc, PrimFunc]: + """Generator of rocwmma_store intrins""" + + @T.prim_func + def rocwmma_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope) + with T.block("root"): + T.reads(A[0:m_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + for i, j in T.grid(m_dim, n_dim): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + @T.prim_func + def rocwmma_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.int32() + s0 = T.int32() + d1 = T.int32() + d0 = T.int32() + A = T.match_buffer( + a, + (m_dim, n_dim), + dtype, + align=64, + offset_factor=16, + scope="wmma.accumulator", + strides=[d1, d0], + ) + C = T.match_buffer( + c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]) + with T.block("root"): + T.reads(A[0:m_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + T.evaluate( + T.tvm_store_matrix_sync( + A.data, + m_dim, + n_dim, + k_dim, + get_rocwmma_fragment_index(A, d1, m_dim, n_dim), + C.access_ptr("w"), + s1, + "row_major", + dtype="handle", + )) + + return rocwmma_store_desc, rocwmma_store_impl + + +def get_rocwmma_sync_intrin(m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, + b_transposed: bool) -> Tuple[PrimFunc, PrimFunc]: + """Generator of rocwmma_sync intrins""" + + def maybe_cast(v): + if in_dtype != out_dtype: + return Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + b_shape_0, b_shape_1 = maybe_swap(k_dim, n_dim) + + @T.prim_func + def rocwmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a") + B = T.match_buffer( + b, + maybe_swap(k_dim, n_dim), + in_dtype, + align=64, + offset_factor=16, + scope="wmma.matrix_b", + ) + C = T.match_buffer( + c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator") + + with T.block("root"): + T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1]) + T.writes(C[0:m_dim, 0:n_dim]) + for i, j, k in T.grid(m_dim, n_dim, k_dim): + with T.block(""): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + B_index_0, B_index_1 = maybe_swap(vkk, vjj) + C[vii, + vjj] = C[vii, + vjj] + maybe_cast(A[vii, vkk]) * maybe_cast(B[B_index_0, B_index_1]) + + @T.prim_func + def rocwmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + a1 = T.int32() + a0 = T.int32() + b1 = T.int32() + b0 = T.int32() + c1 = T.int32() + c0 = T.int32() + + A = T.match_buffer( + a, + (m_dim, k_dim), + in_dtype, + align=64, + offset_factor=16, + scope="wmma.matrix_a", + strides=[a1, a0], + ) + B = T.match_buffer( + b, + maybe_swap(k_dim, n_dim), + in_dtype, + align=64, + offset_factor=16, + scope="wmma.matrix_b", + strides=[b1, b0], + ) + C = T.match_buffer( + c, + (m_dim, n_dim), + out_dtype, + align=64, + offset_factor=16, + scope="wmma.accumulator", + strides=[c1, c0], + ) + + with T.block("root"): + T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1]) + T.writes(C[0:m_dim, 0:n_dim]) + T.evaluate( + T.tvm_mma_sync( + C.data, + get_rocwmma_fragment_index(C, c1, m_dim, n_dim), + A.data, + get_rocwmma_fragment_index(A, a1, m_dim, k_dim), + B.data, + get_rocwmma_fragment_index(B, b1, b_shape_0, b_shape_1), + C.data, + get_rocwmma_fragment_index(C, c1, m_dim, n_dim), + dtype="handle", + )) + + return rocwmma_sync_desc, rocwmma_sync_impl + + +ROCWMMA_SYNC_16x16x16_f16f16f32_INTRIN = "rocwmma_sync_16x16x16_f16f16f32" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_f16f16f32_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "float16", "float32", False), +) + +ROCWMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN = "rocwmma_sync_16x16x16_f16f16f32_trans" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "float16", "float32", True), +) + +ROCWMMA_SYNC_16x16x16_f16f16f16_INTRIN = "rocwmma_sync_16x16x16_f16f16f16" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_f16f16f16_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "float16", "float16", False), +) + +ROCWMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN = "rocwmma_sync_16x16x16_f16f16f16_trans" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "float16", "float16", True), +) + +ROCWMMA_SYNC_16x16x16_s8s8s32_INTRIN = "rocwmma_sync_16x16x16_s8s8s32" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_s8s8s32_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "int8", "int32", False), +) + +ROCWMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN = "rocwmma_sync_16x16x16_s8s8s32_trans" +TensorIntrin.register( + ROCWMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN, + *get_rocwmma_sync_intrin(16, 16, 16, "int8", "int32", True), +) + +ROCWMMA_LOAD_16x16x16_F16_A_INTRIN = "rocwmma_load_16x16x16_f16_a_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_A_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared", False, False), +) + +ROCWMMA_LOAD_16x16x16_F16_A_DYN_INTRIN = "rocwmma_load_16x16x16_f16_a_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_A_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, False), +) + +ROCWMMA_LOAD_16x16x16_F16_B_INTRIN = "rocwmma_load_16x16x16_f16_b_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_B_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared", True, False), +) + +ROCWMMA_LOAD_16x16x16_F16_B_DYN_INTRIN = "rocwmma_load_16x16x16_f16_b_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_B_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, False), +) + +ROCWMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "rocwmma_load_16x16x16_f16_a_trans_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared", False, True), +) + +ROCWMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN = "rocwmma_load_16x16x16_f16_a_trans_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, True), +) + +ROCWMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "rocwmma_load_16x16x16_f16_b_trans_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared", True, True), +) + +ROCWMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN = "rocwmma_load_16x16x16_f16_b_trans_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, True), +) + +ROCWMMA_LOAD_16x16x16_S8_A_INTRIN = "rocwmma_load_16x16x16_s8_a_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_A_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared", False, False), +) + +ROCWMMA_LOAD_16x16x16_S8_A_DYN_INTRIN = "rocwmma_load_16x16x16_s8_a_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_A_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, False), +) + +ROCWMMA_LOAD_16x16x16_S8_B_INTRIN = "rocwmma_load_16x16x16_s8_b_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_B_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared", True, False), +) + +ROCWMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "rocwmma_load_16x16x16_s8_b_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_B_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, False), +) + +ROCWMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "rocwmma_load_16x16x16_s8_a_trans_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared", False, True), +) + +ROCWMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN = "rocwmma_load_16x16x16_s8_a_trans_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, True), +) + +ROCWMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "rocwmma_load_16x16x16_s8_b_trans_shared" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared", True, True), +) + +ROCWMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN = "rocwmma_load_16x16x16_s8_b_trans_shared_dyn" +TensorIntrin.register( + ROCWMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN, + *get_rocwmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True), +) + +ROCWMMA_FILL_16x16x16_F32_INTRIN = "rocwmma_fill_16x16x16_f32" +TensorIntrin.register(ROCWMMA_FILL_16x16x16_F32_INTRIN, + *get_rocwmma_fill_intrin(16, 16, 16, "float32")) + +ROCWMMA_FILL_16x16x16_F16_INTRIN = "rocwmma_fill_16x16x16_f16" +TensorIntrin.register(ROCWMMA_FILL_16x16x16_F16_INTRIN, + *get_rocwmma_fill_intrin(16, 16, 16, "float16")) + +ROCWMMA_FILL_16x16x16_S32_INTRIN = "rocwmma_fill_16x16x16_s32" +TensorIntrin.register(ROCWMMA_FILL_16x16x16_S32_INTRIN, + *get_rocwmma_fill_intrin(16, 16, 16, "int32")) + +ROCWMMA_STORE_16x16x16_F32_SHARED_INTRIN = "rocwmma_store_16x16x16_f32_shared" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_F32_SHARED_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float32", "shared")) + +ROCWMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN = "rocwmma_store_16x16x16_f32_shared_dyn" +TensorIntrin.register( + ROCWMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float32", "shared.dyn"), +) + +ROCWMMA_STORE_16x16x16_F16_SHARED_INTRIN = "rocwmma_store_16x16x16_f16_shared" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_F16_SHARED_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float16", "shared")) + +ROCWMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN = "rocwmma_store_16x16x16_f16_shared_dyn" +TensorIntrin.register( + ROCWMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float16", "shared.dyn"), +) + +ROCWMMA_STORE_16x16x16_S32_SHARED_INTRIN = "rocwmma_store_16x16x16_s32_shared" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_S32_SHARED_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "int32", "shared")) + +ROCWMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN = "rocwmma_store_16x16x16_s32_shared_dyn" +TensorIntrin.register( + ROCWMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "int32", "shared.dyn"), +) + +ROCWMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "rocwmma_store_16x16x16_f32_global" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float32", "global")) + +ROCWMMA_STORE_16x16x16_F16_GLOBAL_INTRIN = "rocwmma_store_16x16x16_f16_global" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "float16", "global")) + +ROCWMMA_STORE_16x16x16_S32_GLOBAL_INTRIN = "rocwmma_store_16x16x16_s32_global" +TensorIntrin.register(ROCWMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, + *get_rocwmma_store_intrin(16, 16, 16, "int32", "global")) diff --git a/bitblas/gpu/matmul.py b/bitblas/gpu/matmul.py index 0cfd65dc4..e09a6fec7 100644 --- a/bitblas/gpu/matmul.py +++ b/bitblas/gpu/matmul.py @@ -28,6 +28,7 @@ MatmulInt8Tensorization, MatmulTensorizationWMMA, ) +from .matmul_mfma import MatmulTensorizationMFMA from functools import reduce import logging @@ -54,7 +55,7 @@ class Config: def get_configs(self, target: Target) -> Config: """Get the schedule config for the target""" - if target.kind.name == "cuda" or target.kind.name == "rocm": + if target.kind.name in {"cuda", "rocm", "hip"}: return Matmul.Config( block_size_x=8, block_size_y=16, @@ -109,10 +110,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if sch is None: return None - # Step 1. Check Tensor Core support + # Step 1. Check hardware supports tensorization. # Tensorization config: # If any value of I, J, K is fixed and less than this threshold, # tensorization rule will not be applied. + #TODO check matrix core support, now there is a trick: MI250 can use Matrix core. minimal_tensorize_threshold = 64 block_stmt = sch.get(main_block) if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: @@ -138,6 +140,23 @@ def apply( # pylint: disable=too-many-locals,missing-docstring tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) if tensorize_sch is not None: return tensorize_sch + elif target.kind.name == "hip": + apply_tensorization: bool = True + # the batch dimension is not taken into consideration. + # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + if in_dtype not in ["int8", "float16"]: + apply_tensorization = False + for item_var in block_stmt.iter_vars[1:]: + extent = item_var.dom.extent + if isinstance(extent, + tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: + apply_tensorization = False + if apply_tensorization: + # For MI250 + tensorize_sch = MatmulTensorizationMFMA().apply(func, target, _) + if tensorize_sch is not None: + return tensorize_sch # Step 2. Get schedule config. config = self.get_configs(target) diff --git a/bitblas/gpu/matmul_mfma.py b/bitblas/gpu/matmul_mfma.py new file mode 100644 index 000000000..af564f245 --- /dev/null +++ b/bitblas/gpu/matmul_mfma.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Optional + +from tvm import tir + +from ..base.roller import Hint +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_analysis import get_reduction_blocks + + +def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 + + def index_map_3d(i, j): + return ( + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(i, j, ii, jj): + return ( + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_warp_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMFMA(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply_config( + self, + func: tir.PrimFunc, + config: Hint, + ) -> Optional[tir.Schedule]: + + from bitblas.gpu.intrin.hip import ( + get_mfma_intrin_group,) + + is_cross_thread_reduce = ( + hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None) + block_reduction_depth = config.block_reduction_depth if is_cross_thread_reduce else 1 + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + #cache_write_required = True + + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mfma_intrin_group( + load_scope=shared_scope, + store_scope="global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + not_use_mfma_store_intrinic=False, + ) + + # Start schedule + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + reduce_k = block_reduction_depth + chunk = int(config.rstep[0] / reduce_k) + + #tensor core intrinsic size + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): + return (r, l) if trans else (l, r) + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + # matrix core not support swizzle + + warp_size = 64 + + block = main_block + + (i, j, k) = sch.get_loops(block) + by, i = sch.split(i, factors=[None, config.block[0]]) + bx, j = sch.split(j, factors=[None, config.block[1]]) + bk, k = sch.split(k, factors=[None, (chunk * micro_size_k)]) + + sch.reorder(by, bx, bk, i, j, k) + + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + + block_tz, block_inner_i = sch.split(i, factors=[block_row_warps, None]) + + block_ty, block_inner_j = sch.split(j, factors=[block_col_warps, None]) + + sch.reorder(block_tz, block_ty, bk, block_inner_i, block_inner_j, k) + + sch.bind(block_tz, "threadIdx.z") + sch.bind(block_ty, "threadIdx.y") + + #schedule the shared memory + def fetch_to_shared(block, idx, vec_len=8, can_swizzle=False, is_smooth=False, reduce_k=1): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, bk, preserve_unit_loops=True) + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + _, f_0, f_1, f_2, f_3 = sch.split( + fused, factors=[None, block_row_warps, block_col_warps, warp_size, vec_len]) + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.bind(f_0, "threadIdx.z") + sch.vectorize(f_3) + + # fetch A,B to shared + # 0->A, 1->B + fetch_to_shared(main_block, 0) + fetch_to_shared(main_block, 1) + + # blockize for mma tensorize + block_inner_i, block_inner_i_tc = sch.split(block_inner_i, factors=[None, micro_size_x]) + block_inner_j, block_inner_j_tc = sch.split(block_inner_j, factors=[None, micro_size_y]) + k, k_tc = sch.split(k, factors=[None, micro_size_k]) + + if intrin_info.trans_b: + sch.reorder(k, block_inner_i, block_inner_j, block_inner_i_tc, block_inner_j_tc, k_tc) + else: + sch.reorder(block_inner_i, block_inner_j, k, block_inner_i_tc, block_inner_j_tc, k_tc) + + A_mat = sch.cache_read(main_block, 0, "warp") + B_mat = sch.cache_read(main_block, 1, "warp") + sch.compute_at(A_mat, k) + sch.compute_at(B_mat, k) + + C_store = sch.cache_write(main_block, 0, "warp") + + sch.reverse_compute_at(C_store, block_ty) + + i, j = sch.get_loops(C_store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + def tile_wmma_fragment(block_read, height, width): + i, j = sch.get_loops(block_read)[-2:] + i0, i1 = sch.split(i, factors=[None, height]) + j0, j1 = sch.split(j, factors=[None, width]) + sch.reorder(i0, j0, i1, j1) + return i1 + + if intrin_info.trans_b: + a_loop_warp = tile_wmma_fragment(A_mat, micro_size_x, micro_size_k) + b_loop_warp = tile_wmma_fragment(B_mat, micro_size_k, micro_size_y) + else: + a_loop_warp, _ = sch.get_loops(A_mat)[-2:] + b_loop_warp, _ = sch.get_loops(B_mat)[-2:] + + block_init_c = sch.decompose_reduction(main_block, bk) + + # Tensorization by hardware intrinsics + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout(A_mat, ("write", 0), + get_warp_index_map(index_map_a, *b_lr, intrin_info.inter_transform_a)) + + sch.transform_layout( + B_mat, + ("write", 0), + get_warp_index_map(index_map_b, *a_lr, intrin_info.inter_transform_b), + ) + + sch.transform_layout( + C_store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=False), + ) + + sch.tensorize(a_loop_warp, intrin_group["load_a"]) + sch.tensorize(b_loop_warp, intrin_group["load_b"]) + + sch.tensorize(block_inner_i_tc, intrin_group["compute"]) + + sch.tensorize(sch.get_loops(block_init_c)[-2], intrin_group["init"]) + + sch.tensorize(sch.get_loops(C_store)[-2], intrin_group["store"]) + + return sch diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 327cd1a3d..dafdc8173 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -6,6 +6,7 @@ import operator from functools import reduce from bitblas.base.arch.cuda import CUDA +from bitblas.base.arch.cdna import CDNA from bitblas.base.roller.hint import Hint from typing import Any, Literal, Optional, Tuple, Union from ..operator import OperatorConfig, Operator, OPExecutorCPU, BaseKernelNameGenerator @@ -379,8 +380,8 @@ def __init__( ) target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") + if target.kind.name not in ("cuda", "hip"): + raise ValueError("Currently only support cuda and hip target") self.dispatch_tir(target, from_database, source_format, enable_tuning) @@ -390,7 +391,10 @@ def dispatch_tir(self, source_format: str = "uint", enable_tuning: bool = True): '''Dispatch the tir script implementation''' - self.arch = CUDA(target) + if (target.kind.name == "cuda"): + self.arch = CUDA(target) + elif (target.kind.name == "hip"): + self.arch = CDNA(target) if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index c3406f6a0..d8366a332 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -37,7 +37,7 @@ def __init__( super().__init__(name, config, target) target = self.target - if target.kind.name == "cuda": + if target.kind.name == "cuda" or target.kind.name == "hip": self.scheduled_ir_module = self.apply_default_schedule(self.ir_module, target) if enable_tuning: self.hardware_aware_finetune() diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 94ea042cd..ba3927005 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -177,6 +177,52 @@ def _build_runtime_module(self, target: Target): def tvm_callback_cuda_postproc(code, _): return self.post_process(code) + try: + with tvm.transform.PassContext( + config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + **(self.pass_context if self.pass_context else {}), + }): + if self.is_tir_backend(): + rt_mod = tvm.build(self.scheduled_ir_module, target=target) + elif self.is_tilelang_backend(): + # check only have one function in the module + if len(self.scheduled_ir_module.functions) > 1: + raise ValueError("Only support one function in the module") + tl_prim_func = list(self.scheduled_ir_module.functions.values())[0] + with tvm.transform.PassContext( + config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + **(self.pass_context if self.pass_context else {}) + }): + rt_mod = tl.lower(tl_prim_func, target=target, runtime_only=True) + else: + raise ValueError(f"Unsupported backend: {self.backend}") + except Exception as build_runtime_error: # noqa: F841 + error_message = str(build_runtime_error) + # Truncate only if the message exceeds the maximum length + if len(error_message) > MAX_ERROR_MESSAGE_LENGTH: + truncated_message = f"{error_message[-MAX_ERROR_MESSAGE_LENGTH:]} [...]" + else: + truncated_message = error_message + + logger.debug( + BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format( + self.__class__.__name__, + target, + "optimized", + truncated_message, + )) + elif self.arch.platform == "CDNA": + if self.scheduled_ir_module is None: + return None + + @tvm.register_func(func_name="tvm_callback_hip_postproc", override=True) + def tvm_callback_hip_postproc(code, _): + return self.post_process(code) + try: with tvm.transform.PassContext( config={ @@ -216,7 +262,7 @@ def tvm_callback_cuda_postproc(code, _): truncated_message, )) else: - # For non-CUDA platforms or when no optimized function is available, build with the primary function + # For non-CUDA and non-hip platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) # If the runtime module was successfully built, set up for evaluation @@ -226,7 +272,7 @@ def tvm_callback_cuda_postproc(code, _): self.time_evaluator = rt_mod.time_evaluator( rt_mod.entry_name, self.arch.device, number=10) self.torch_func = to_pytorch_func(rt_mod) - if self.arch.platform == "CUDA": + if self.arch.platform in {"CUDA", "CDNA"}: try: is_dynamic = ( self.dynamic_range is not None and diff --git a/testing/python/amd/test_backend_hip_wrapper_matmul.py b/testing/python/amd/test_backend_hip_wrapper_matmul.py new file mode 100644 index 000000000..235c4d292 --- /dev/null +++ b/testing/python/amd/test_backend_hip_wrapper_matmul.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level +from bitblas.builder.wrapper import TIRWrapper +import tvm + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +def matmul_backend_code_wrap( + M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_bias, +): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + with_bias=with_bias, + ) + target = tvm.target.Target("hip") + matmul = Matmul(config=matmul_config, target=target, enable_tuning=False) + backend = TIRWrapper(arch=matmul.arch) + backend.assign_optimized_module(matmul.scheduled_ir_module) + is_dynamic = ( + matmul.dynamic_range is not None and len(matmul.scheduled_ir_module.functions) > 1) + wrapped_code = backend.wrap(matmul.get_source(kenrel_only=True), is_dynamic=is_dynamic) + assert "void call" in wrapped_code + + +@tvm.testing.requires_rocm +def test_matmul_transform_weight(): + matmul_backend_code_wrap(128, 128, 128, "float16", "float16", "float16", "float16", False) + matmul_backend_code_wrap(1, 256, 256, "float16", "float16", "uint4", "float16", False) + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/amd/test_matmul_mfma_schedule_trans_b.py b/testing/python/amd/test_matmul_mfma_schedule_trans_b.py new file mode 100644 index 000000000..1306c3442 --- /dev/null +++ b/testing/python/amd/test_matmul_mfma_schedule_trans_b.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import tvm +from bitblas.ops.general_matmul.tirscript import ( + matmul_select_implementation,) +import logging +from bitblas import set_log_level +import numpy as np + +np.random.seed(0) + +set_log_level(logging.DEBUG) + + +# fmt: off +def assert_correctness_with_block_reduce( + M=256, + N=256, + K=256, + in_dtype="float16", + out_dtype="float32", + accum_dtype="float32", + propagate_a=0, + propagate_b=0, +): + matmul_func = matmul_select_implementation( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout="nt")["main"] + target = tvm.target.Target("hip") + intrin_info = bitblas.base.hint.IntrinInfo( + in_dtype=in_dtype, + out_dtype=out_dtype, + trans_b=True, + input_transform_kind=propagate_a, + weight_transform_kind=propagate_b, + ) + arch = bitblas.base.CDNA(target=target) + ref_sch = bitblas.gpu.MatmulTensorizationMFMA().apply_config( + matmul_func, + config=bitblas.base.Hint.from_dict({ + "arch": arch, + "block": [128, 128], + "warp": [64, 64], + "rstep": [32], + "chunk": [2], + "block_reduction_depth": 16, + "pipeline_stage": 2, + "use_async": True, + "intrin_info": intrin_info, + "shared_scope": "shared", + "vectorize": { + "b": 8, + "a": 8 + }, + }), + ) + + with tvm.transform.PassContext(): + ref_rt_mod = tvm.build(ref_sch.mod, target=target) + + ctx = tvm.rocm(0) + np.random.seed(0) + a_np = (np.random.rand(M, K)).astype("float16") + print(a_np) + b_np = (np.random.rand(N, K)).astype("float16") + + rocm_a = tvm.nd.array((a_np).astype("float16"), ctx) + rocm_b = tvm.nd.array((b_np).astype("float16"), ctx) + rocm_c = tvm.nd.array(np.zeros((M, N)).astype("float32"), ctx) + + ref_rt_mod(rocm_a, rocm_b, rocm_c) + + c_np = rocm_c.numpy() + np.testing.assert_allclose( + c_np, np.matmul(a_np.astype("float32"), b_np.astype("float32").T), rtol=1e-2, atol=1e-2 + ) + print(c_np) + print(np.matmul(a_np.astype("float32"), b_np.astype("float32").T)) + +@tvm.testing.requires_rocm +def test_assert_correctness_with_block_reduce(): + assert_correctness_with_block_reduce( + M=256, + N=256, + K=256, + in_dtype="float16", + out_dtype="float32", + accum_dtype="float32", + propagate_a=0, + propagate_b=0) + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/tutorials/ladder_from_onnx.py b/tutorials/ladder_from_onnx.py index ce07e7265..2cc981863 100644 --- a/tutorials/ladder_from_onnx.py +++ b/tutorials/ladder_from_onnx.py @@ -9,7 +9,7 @@ from tvm.contrib import graph_executor from ladder.utils import write_mod import os -import torch +import torch # noqa: F401 import logging ladder.set_log_level(logging.INFO) @@ -26,52 +26,56 @@ parser.add_argument('--cublas', action="store_true") parser.add_argument('--cudnn', action="store_false") parser.add_argument('--nhwc', action="store_false") -parser.add_argument('--async_propagation', action="store_true", help="Use async propagation and async instructions, which should be only enabled on data center GPUs with async copy instructions.", default=False) -parser.add_argument("--prebuilt_path", type=str, default=None, help="Path to the prebuilt model. If set, the script will run from the prebuilt model.") +parser.add_argument( + '--async_propagation', + action="store_true", + help="Use async propagation and async instructions, which should be only enabled on data center GPUs with async copy instructions.", + default=False) +parser.add_argument( + "--prebuilt_path", + type=str, + default=None, + help="Path to the prebuilt model. If set, the script will run from the prebuilt model.") parser.add_argument("--fast_decoding", action="store_false", help="Enable fast decoding mode.") args = parser.parse_args() + def run(prefix, arch, async_propagate): if ".onnx" in prefix: onnx_model = onnx.load(prefix) else: onnx_model = onnx.load(osp.join(prefix, "model.onnx")) - mod, params = relay.frontend.from_onnx( - onnx_model, convert_config={"use_welder_matmul": False}) + mod, params = relay.frontend.from_onnx(onnx_model, convert_config={"use_welder_matmul": False}) write_mod(mod, log_path, "load_from_onnx") if args.nhwc: - # must convert bias_add -> broadcast_add to propogate the layout + # must convert bias_add -> broadcast_add to propagate the layout mod = relay.transform.InferType()(mod) mod = relay.transform.CanonicalizeOps()(mod) write_mod(mod, log_path, "CanonicalizeOps") - mod = relay.transform.ConvertLayout( - {"nn.conv2d": ["NHWC", "default"]})(mod) + mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod) write_mod(mod, log_path, "ConvertLayout") mod = relay.transform.FoldConstant()(mod) write_mod(mod, log_path, "FoldConstant") mod = ladder.relay.transform.WelderExprRewrite(enable_softmax=True)(mod) write_mod(mod, log_path, "expr_rewrite") - + if args.cudnn: from tvm.relay.op.contrib.cudnn import pattern_table - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.MergeComposite(pattern_table()), - relay.transform.AnnotateTarget("cudnn"), - relay.transform.PartitionGraph(bind_constants=False), - relay.transform.InferType(), - ] - ) + seq = tvm.transform.Sequential([ + relay.transform.InferType(), + relay.transform.MergeComposite(pattern_table()), + relay.transform.AnnotateTarget("cudnn"), + relay.transform.PartitionGraph(bind_constants=False), + relay.transform.InferType(), + ]) mod = seq(mod) - mod = ladder.relay.transform.LadderConvImplicitGemm( - use_async_propagation=async_propagate)(mod) + mod = ladder.relay.transform.LadderConvImplicitGemm(use_async_propagation=async_propagate)(mod) write_mod(mod, log_path, "LadderConvImplicitGemm") - mod = ladder.relay.transform.LadderPerfectGemmTransform( - use_async_propagation=async_propagate)(mod) + mod = ladder.relay.transform.LadderPerfectGemmTransform(use_async_propagation=async_propagate)( + mod) write_mod(mod, log_path, "LadderPerfectGemmTransform") mod = ladder.relay.transform.WelderConvImplicitGemm()(mod) write_mod(mod, log_path, "WelderConvImplicitGemm") @@ -83,15 +87,13 @@ def run(prefix, arch, async_propagate): write_mod(mod, log_path, "LadderRewriteInceptionLayout") if args.cublas: from tvm.relay.op.contrib.cublas import pattern_table - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.MergeComposite(pattern_table()), - relay.transform.AnnotateTarget("cublas"), - relay.transform.PartitionGraph(bind_constants=False), - relay.transform.InferType(), - ] - ) + seq = tvm.transform.Sequential([ + relay.transform.InferType(), + relay.transform.MergeComposite(pattern_table()), + relay.transform.AnnotateTarget("cublas"), + relay.transform.PartitionGraph(bind_constants=False), + relay.transform.InferType(), + ]) mod = seq(mod) write_mod(mod, log_path, "cublas_partition") mod = relay.transform.DeadCodeElimination()(mod) @@ -110,7 +112,9 @@ def run(prefix, arch, async_propagate): mod = ladder.relay.transform.AnnotateFastDecoding()(mod) write_mod(mod, log_path, "AnnotateFastDecoding") - mod = ladder.relay.transform.WelderTunePass(arch, topk=40,save_perf_log="./debug_group_info")(mod) + mod = ladder.relay.transform.WelderTunePass( + arch, topk=40, save_perf_log="./debug_group_info")( + mod) write_mod(mod, log_path, "WelderTunePass") factory = relay.build(mod, arch.target, params=params) @@ -119,8 +123,7 @@ def run(prefix, arch, async_propagate): f.write(factory.get_graph_json()) with open(osp.join(log_path, "graph.params"), "wb") as f_params: f_params.write(tvm.runtime.save_param_dict(factory.get_params())) - lib = ladder.relay.update_lib( - factory.get_lib(), arch, osp.join(log_path, "model.so")) + lib = ladder.relay.update_lib(factory.get_lib(), arch, osp.join(log_path, "model.so")) rt_mod = graph_executor.create(factory.get_graph_json(), lib, tvm.cuda(0)) rt_mod.set_input(**factory.get_params())