diff --git a/3rdparty/tvm b/3rdparty/tvm index 1cc769cd7..321f4151d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1cc769cd75cc9a497c5077cb71e68d7e60225f28 +Subproject commit 321f4151dbe2ed57fcd722530c79f3658ec013fd diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py index 27581d2f8..ad6080914 100644 --- a/bitblas/base/arch/__init__.py +++ b/bitblas/base/arch/__init__.py @@ -15,3 +15,18 @@ def get_arch(target: tvm.target.Target) -> TileDevice: return CDNA(target) else: raise ValueError(f"Unsupported target: {target.kind.name}") + + +def is_ampere_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(isinstance(arch, CUDA)) + conditions.append(arch.sm_version >= 80) + return all(conditions) + + +def is_volta_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(isinstance(arch, CUDA)) + conditions.append(arch.sm_version >= 70) + conditions.append(arch.sm_version < 80) + return all(conditions) diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index f18c98026..f0e35662c 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -1,10 +1,12 @@ from tvm import IRModule from tvm.tir import PrimFunc -from typing import Union, Callable +from typing import Union, Callable, List from dataclasses import dataclass, field from tvm.tl.transform import Simplify from abc import ABC, abstractmethod from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.tl.base_hint import BaseTLHint # Decorator to simplify the output of a function @@ -54,7 +56,7 @@ def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]): return stmt @abstractmethod - def with_default_config(self): + def with_default_config(self) -> PrimFunc: pass @abstractmethod @@ -65,6 +67,10 @@ def apply_config( ): pass + def serialze_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]: + # Convert Roller Hints to TileLang Hints + raise 
NotImplementedError + @property def common_header(self): # TODO(lei): For HIP Backend it should be different diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index dafdc8173..c62e785b5 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -5,8 +5,6 @@ from tvm.target import Target import operator from functools import reduce -from bitblas.base.arch.cuda import CUDA -from bitblas.base.arch.cdna import CDNA from bitblas.base.roller.hint import Hint from typing import Any, Literal, Optional, Tuple, Union from ..operator import OperatorConfig, Operator, OPExecutorCPU, BaseKernelNameGenerator @@ -290,24 +288,25 @@ def generate(self, hint=None) -> str: precision_str = (f"{A_dtype}x{W_dtype}") kernel_name = "_".join([kernel_name, shape_str, precision_str]) - - # if config.with_scaling: - # kernel_name += "Scale" - - # if config.with_zeros: - # if config.zeros_mode == "original": - # kernel_name += "OriginalZeros" - # elif config.zeros_mode == "rescale": - # precision_str += "RescaleZeros" - # elif config.zeros_mode == "quantized": - # precision_str += "QuantizedZeros" - # else: - # raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}") - - # if config.propagate_a is not TransformKind.NonTransform: - # kernel_name += f"_pa{config.propagate_a.value}" - # if config.propagate_b is not TransformKind.NonTransform: - # kernel_name += f"_pb{config.propagate_b.value}" + ''' + if config.with_scaling: + kernel_name += "Scale" + + if config.with_zeros: + if config.zeros_mode == "original": + kernel_name += "OriginalZeros" + elif config.zeros_mode == "rescale": + precision_str += "RescaleZeros" + elif config.zeros_mode == "quantized": + precision_str += "QuantizedZeros" + else: + raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}") + + if config.propagate_a is not TransformKind.NonTransform: + kernel_name += f"_pa{config.propagate_a.value}" + if config.propagate_b is not 
TransformKind.NonTransform: + kernel_name += f"_pb{config.propagate_b.value}" + ''' kernel_name = "_".join([kernel_name, self.serialize_hint(hint)]) assert self.is_valid(kernel_name), "Kernel name invalid" @@ -390,16 +389,12 @@ def dispatch_tir(self, from_database: bool = False, source_format: str = "uint", enable_tuning: bool = True): - '''Dispatch the tir script implementation''' - if (target.kind.name == "cuda"): - self.arch = CUDA(target) - elif (target.kind.name == "hip"): - self.arch = CDNA(target) if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} - self.ir_module["main"] = self.ir_module["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) + if self.is_tir_backend(): + self.ir_module["main"] = self.ir_module["main"].with_attrs( + {"opt_shapes": self.dynamic_range}) else: self.dynamic_range = None @@ -600,6 +595,7 @@ def _select_implementation(self): def _select_scheduler(self): if is_native_compute(self.A_dtype, self.W_dtype): return consistent_scheduler( + arch=self.arch, M=self.M, N=self.N, K=self.K, @@ -613,6 +609,7 @@ def _select_scheduler(self): ) else: return weight_dequantize_scheduler( + arch=self.arch, M=self.M, N=self.N, K=self.K, diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index d3c9b38aa..442b4c6f2 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -12,7 +12,7 @@ ) from .matmul_tensorcore import ( - MatmulScheduler, # noqa: F401 + MatmulBlockScheduler, # noqa: F401 MatmulFineGrainScheduler, # noqa: F401 MatmulWeightPropagationScheduler, # noqa: F401 ) @@ -22,6 +22,11 @@ MatmulINT4WeightPropagationScheduler, # noqa: F401 ) +from bitblas.base.roller import TileDevice +from bitblas.base.arch import ( + is_ampere_arch, + is_volta_arch, +) from bitblas.ops.common import TransformKind from typing import Union @@ -40,7 +45,52 @@ def is_non_transform_kind(kind) -> bool: return 
kind == TransformKind.NonTransform -def select_scheduler( +def volta_select_schduler( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + trans_A, trans_B = parse_layout(layout) + if isinstance(propagate_a, int): + propagate_a = TransformKind(propagate_a) + if isinstance(propagate_b, int): + propagate_b = TransformKind(propagate_b) + + def check_if_not_supported(): + conditions = [True] + conditions.append(propagate_a == TransformKind.NonTransform) + conditions.append(propagate_b == TransformKind.NonTransform) + conditions.append(trans_A is False) + conditions.append(trans_B is True) + conditions.append(in_dtype in ["int8", "float16", "float32"]) + conditions.append(accum_dtype in ["int32", "float32"]) + return all(conditions) + + if not check_if_not_supported(): + raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}") + + Scheduler = MatmulFineGrainSIMTScheduler + return Scheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + ) + + +def ampere_select_scheduler( M=None, N=16384, K=16384, @@ -60,8 +110,6 @@ def select_scheduler( propagate_a = TransformKind(propagate_a) if isinstance(propagate_b, int): propagate_b = TransformKind(propagate_b) - if with_bias: - raise NotImplementedError trans_A, trans_B = parse_layout(layout) @@ -102,6 +150,7 @@ def is_int4_dtype(dtype): in_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, + with_bias=with_bias, ) if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): Scheduler = MatmulFineGrainScheduler if not is_int4_dtype( @@ -115,9 +164,10 @@ def is_int4_dtype(dtype): in_dtype=in_dtype, out_dtype=out_dtype, 
accum_dtype=accum_dtype, + with_bias=with_bias, ) elif can_apply_block_scheduler(propagate_a, propagate_b): - return MatmulScheduler( + return MatmulBlockScheduler( M=M, N=N, K=K, @@ -126,6 +176,50 @@ def is_int4_dtype(dtype): in_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, + with_bias=with_bias, ) else: raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}") + + +def select_scheduler( + arch: TileDevice, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + if is_ampere_arch(arch): + return ampere_select_scheduler( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + layout=layout, + propagate_a=propagate_a, + propagate_b=propagate_b, + ) + elif is_volta_arch(arch): + return volta_select_schduler( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + layout=layout, + propagate_a=propagate_a, + propagate_b=propagate_b, + ) + else: + raise ValueError(f"Unsupported arch: {arch.name}") diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 76d756e96..03373b6fe 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -1,16 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from bitblas import tvm as tvm -from typing import Optional +from typing import Optional, List from bitblas.ops.base_scheduler import BaseScheduler +import tvm.tl.language as T +from tvm import DataType +from tvm.tir import PrimFunc from dataclasses import dataclass +from bitblas.base.utils import get_roller_hints_from_func +from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) +from bitblas.base.arch import TileDevice +from bitblas.tl.base_hint import BaseTLHint +from bitblas.base.roller.hint import Hint @dataclass -class MatmulFineGrainSIMTScheduler(BaseScheduler): - # Fine-grained matrix multiplication scheduler - # Allows for more detailed configuration. +class MatmulSIMTBaseScheduler(BaseScheduler): + # Base class for matrix multiplication scheduler + # Contains the basic configuration for matrix multiplication # Operation Configuration M: Optional[int] = None @@ -21,32 +29,242 @@ class MatmulFineGrainSIMTScheduler(BaseScheduler): trans_A: bool = False trans_B: bool = True accum_dtype: str = "float16" + with_bias: bool = False + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute Expression + ir_module = matmul_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + ) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=False, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + return self.serialze_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + return self.get_roller_configs(arch, topk) + + # check if required shared memory cache + def check_require_cache(self) -> bool: + with_bias = self.with_bias + + conditions: List[bool] = [] + 
conditions.append(False) + # Bias Add should be done in shared memory + conditions.append(with_bias) + return any(conditions) # Always set to False Currently + + +@dataclass +class MatmulFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): + # Fine-grained matrix multiplication scheduler + # Allows for more detailed configuration. # Tensor Core Warp Configuration - block_row_warps: int = 2 - block_col_warps: int = 2 - warp_row_tiles: int = 32 - warp_col_tiles: int = 32 - chunk: int = 32 # Usually determines the K-dimension split size + block_size_x: int = 8 + block_size_y: int = 8 + thread_row_tiles: int = 16 + thread_col_tiles: int = 16 + chunk: int = 16 # Usually determines the K-dimension split size + + class TLHint(BaseTLHint): + + def __init__(self): + super().__init__() + + @classmethod + def from_roller_hint(cls, hint: Hint): + tl_hint = cls() + for key, value in hint.__dict__.items(): + setattr(tl_hint, key, value) + + block_row_warps = hint.block[0] // (hint.thread[0] * hint.step[0]) + block_col_warps = hint.block[1] // (hint.thread[1] * hint.step[1]) + thread_row_tiles = hint.thread[0] // (hint.step[0] * 2) + thread_col_tiles = hint.thread[1] // (hint.step[1] * 2) + vthread_row_tiles = (hint.step[0] * 2) # expand vtrhead to avoid load band conflict + vthread_col_tiles = (hint.step[1] * 2) # expand vtrhead to avoid load band conflict + chunk = hint.rstep[0] + + tl_hint.block_size_x = block_row_warps + tl_hint.block_size_y = block_col_warps + tl_hint.thread_row_tiles = thread_row_tiles + tl_hint.thread_col_tiles = thread_col_tiles + tl_hint.vthread_row_tiles = vthread_row_tiles + tl_hint.vthread_col_tiles = vthread_col_tiles + tl_hint.chunk = chunk + + return tl_hint + + def get_config_params(self): + return { + "block_size_x": self.block_size_x, + "block_size_y": self.block_size_y, + "thread_row_tiles": self.thread_row_tiles, + "thread_col_tiles": self.thread_col_tiles, + "chunk": self.chunk, + } + + def __repr__(self): + return ("{" + f"block_size_x: 
{self.block_size_x}, " + f"block_size_y: {self.block_size_y}, " + f"thread_row_tiles: {self.thread_row_tiles}, " + f"thread_col_tiles: {self.thread_col_tiles}, " + f"chunk: {self.chunk}" + "}") + + def serialze_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + def with_default_config(self) -> PrimFunc: + block_size_x = getattr(self, "block_size_x", 2) + block_size_y = getattr(self, "block_size_y", 2) + thread_row_tiles = getattr(self, "thread_row_tiles", 16) + thread_col_tiles = getattr(self, "thread_col_tiles", 16) + chunk = getattr(self, "chunk", 16) + + return self.apply_config( + block_size_x=block_size_x, + block_size_y=block_size_y, + thread_row_tiles=thread_row_tiles, + thread_col_tiles=thread_col_tiles, + chunk=chunk, + ) + + def apply_config( + self, + block_size_x: Optional[int] = None, + block_size_y: Optional[int] = None, + thread_row_tiles: Optional[int] = None, + thread_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + ): + assert block_size_x is not None, "block_size_x must be provided" + assert block_size_y is not None, "block_size_y must be provided" + assert thread_row_tiles is not None, "thread_row_tiles must be provided" + assert thread_col_tiles is not None, "thread_col_tiles must be provided" + assert chunk is not None, "chunk must be provided" + + M, N, K = self.M, self.N, self.K + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + + shared_scope = "shared.dyn" + + block_M = block_size_x * thread_row_tiles + block_N = block_size_y * thread_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + C_shape = (M, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + + threads = thread_row_tiles * thread_col_tiles + local_size_a = block_M // thread_row_tiles + local_size_b = block_N // thread_col_tiles + local_size_c = (block_M 
// thread_row_tiles) * (block_N // thread_col_tiles) + + micro_size_k = 128 // DataType(in_dtype).bits + + dp4a_size = 4 + use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + + A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype) + B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype) + C_local = T.alloc_local((local_size_c,), accum_dtype) + + thread_binding = T.thread_binding(threads, "threadIdx.x") + + warp_m = thread_binding % thread_row_tiles + warp_n = thread_binding // thread_row_tiles + + T.clear(C_local) + + for ko in T.serial(K // block_K): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - # Tiling and Other Optimization Parameters - num_stages: int = 2 - enable_rasterization: bool = False + for ki in T.serial((block_K // micro_size_k)): + for i in T.serial(local_size_a): + for mk in T.vectorized(micro_size_k): + A_local[i, mk] = A_shared[warp_m * local_size_a + i, + ki * micro_size_k + mk] - def with_default_config(self): - raise NotImplementedError + for i in T.serial(local_size_b): + for mk in T.vectorized(micro_size_k): + B_local[i, mk] = B_shared[warp_n * local_size_b + i, + ki * micro_size_k + mk] - def apply_config(self,): + for i, j in T.grid(local_size_a, local_size_b): + for mk in T.serial(micro_size_k // dp4a_size): + if use_dp4a: + T.dp4a( + A_local[i, mk * dp4a_size], + B_local[j, mk * dp4a_size], + C_local[i * local_size_b + 
j], + ) + else: + for dp4a_idx in T.serial(dp4a_size): + C_local[i * local_size_b + j] += ( + A_local[i, mk * dp4a_size + dp4a_idx] * + B_local[j, mk * dp4a_size + dp4a_idx]) - # M, N, K = self.M, self.N, self.K - # trans_A, trans_B = self.trans_A, self.trans_B - # in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + for i, j in T.grid(local_size_a, local_size_b): + C[ + by * block_M + warp_m * local_size_a + i, + bx * block_N + warp_n * local_size_b + j, + ] = C_local[i * local_size_b + j] - raise NotImplementedError + return self.maybe_simplify(main) def __post_init__(self): # Validate the matrix transpose settings assert self.trans_A is False, "Currently only support Matrix A not transposed" assert self.trans_B is True, "Currently only support Matrix B transposed" + assert self.with_bias is False, "Currently only support without bias" return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 1b083eafb..c5e0ec2d3 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -28,8 +28,7 @@ @dataclass -class MatmulScheduler(BaseScheduler): - +class MatmulBaseScheduler(BaseScheduler): # OP Related Config M: Optional[int] = None N: Optional[int] = None @@ -39,6 +38,51 @@ class MatmulScheduler(BaseScheduler): in_dtype: str = "float16" out_dtype: str = "float16" accum_dtype: str = "float16" + with_bias: bool = False + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute Expression + ir_module = matmul_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + ) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + 
tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + return self.serialze_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + return self.get_roller_configs(arch, topk) + + # check if required shared memory cache + def check_require_cache(self) -> bool: + with_bias = self.with_bias + + conditions: List[bool] = [] + conditions.append(False) + # Bias Add should be done in shared memory + conditions.append(with_bias) + return any(conditions) # Always set to False Currently + + +@dataclass +class MatmulBlockScheduler(MatmulBaseScheduler): # Default Tile Related Params block_M: int = 64 @@ -126,42 +170,12 @@ def get_configs_sm80(self): configs = [{**c, 'num_stages': num_stages} for c in configs] return configs - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - # Simple TIR Compute Expression - ir_module = matmul_select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - ) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - def serialze_hints_to_configs(hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - return serialze_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - return self.get_roller_configs(arch, topk) + def serialze_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs def 
with_default_config(self): block_M = getattr(self, "block_M", 64) @@ -199,9 +213,12 @@ def apply_config( M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) + Bias_shape = (N,) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @@ -209,7 +226,8 @@ def apply_config( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + C: T.Buffer(C_shape, out_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): @@ -230,6 +248,11 @@ def main( else: T.copy(B[k * block_K, bx * block_N], B_shared) T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] += Bias[bx * block_N + j] + T.copy(C_local, C[by * block_M, bx * block_N]) return self.maybe_simplify(main) @@ -240,20 +263,10 @@ def __post_init__(self): @dataclass -class MatmulFineGrainScheduler(BaseScheduler): +class MatmulFineGrainScheduler(MatmulBaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. 
- # Operation Configuration - M: Optional[int] = None - N: Optional[int] = None - K: Optional[int] = None - in_dtype: str = "float16" - out_dtype: str = "float16" - trans_A: bool = False - trans_B: bool = True - accum_dtype: str = "float16" - # Tensor Core Warp Configuration block_row_warps: int = 2 block_col_warps: int = 2 @@ -325,47 +338,12 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - M = self.M - # This is a hack to utilize tensor core - if isinstance(M, int) and M < 16: - M = 16 - - # Simple TIR Compute Expression - ir_module = matmul_select_implementation( - M=M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - ) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - def serialze_hints_to_configs(hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - return serialze_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - return self.get_roller_configs(arch, topk) + def serialze_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs def with_default_config(self): block_row_warps = getattr(self, "block_row_warps", 2) @@ -409,6 +387,7 @@ def apply_config( M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias # Calculate the micro size per warp 
using a helper function micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) @@ -420,6 +399,8 @@ def apply_config( # Define the shapes of matrices and shared memory buffers A_shape = (M, K) B_shape = (N, K) + C_shape = (M, N) + Bias_shape = (N,) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) C_shared_shape = ( @@ -454,12 +435,16 @@ def apply_config( chunk=chunk, ) + # cache_write_required = self.check_require_cache() + cache_write_required = False + # Define the main kernel using the generated configuration @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + C: T.Buffer(C_shape, out_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( @@ -521,21 +506,41 @@ def main( # Matrix multiplication on fragments mma_emitter.mma(A_local, B_local, C_local) - # Store the result back to C shared memory - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + if cache_write_required: + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Do bias addition + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] += Bias[bx * block_N + j] + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + # Store the result directly to global memory + mma_emitter.stmatrix( + C_local, + C, 
+ thread_bindings=thread_bindings, + pid_m=by, + pid_n=bx, + ) return self.maybe_simplify(main) @@ -567,6 +572,7 @@ def apply_config( M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias # Calculate the micro size per warp using a helper function micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) @@ -586,6 +592,9 @@ def apply_config( # Define the shapes of matrices and shared memory buffers A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + C_shape = (M, N) + Bias_shape = (N,) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, @@ -628,12 +637,14 @@ def apply_config( transform_kind_b=self.weight_transform_kind, ) + cache_write_required = self.check_require_cache() # Define the main kernel using the generated configuration @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + C: T.Buffer(C_shape, out_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( @@ -704,21 +715,41 @@ def main( # Matrix multiplication on fragments mma_emitter.mma(A_local, B_local, C_local) - # Store the result back to C shared memory - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + if cache_write_required: + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Do bias addition + if with_bias: + for i, j in T.Parallel(block_M, block_N): + 
C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] += Bias[bx * block_N + j] + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + # Store the result directly to global memory + mma_emitter.stmatrix( + C_local, + C, + thread_bindings=thread_bindings, + pid_m=by, + pid_n=bx, + ) return self.maybe_simplify(main) @@ -748,6 +779,7 @@ def matmul_blocked( ): A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) @@ -755,7 +787,7 @@ def matmul_blocked( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index f4943bfe0..5aa3cab82 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -21,6 +21,11 @@ MatmulINT4DequantizeWeightPropagationScheduler, # noqa: F401 ) +from bitblas.base.roller import TileDevice +from bitblas.base.arch import ( + is_ampere_arch, + is_volta_arch, +) from bitblas.ops.common import TransformKind from typing import Union @@ -39,7 +44,54 @@ def is_non_transform_kind(kind) -> bool: return kind == TransformKind.NonTransform -def select_scheduler( +def volta_select_scheduler( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + 
storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + ''' + Fine-grained Interface is preferred as it provides more flexibility + and can be used to implement high performance kernel. + ''' + if isinstance(propagate_a, int): + propagate_a = TransformKind(propagate_a) + if isinstance(propagate_b, int): + propagate_b = TransformKind(propagate_b) + + trans_A, trans_B = parse_layout(layout) + + def check_if_not_supported(): + conditions = [True] + conditions.append(propagate_a == TransformKind.NonTransform) + conditions.append(propagate_b == TransformKind.NonTransform) + conditions.append(trans_A is False) + conditions.append(trans_B is True) + conditions.append(in_dtype in ["int8", "float16", "float32"]) + conditions.append(accum_dtype in ["int32", "float32"]) + return all(conditions) + + if not check_if_not_supported(): + raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}") + + raise NotImplementedError + + +def ampere_select_scheduler( M=None, N=1024, K=1024, @@ -67,8 +119,6 @@ def select_scheduler( propagate_a = TransformKind(propagate_a) if isinstance(propagate_b, int): propagate_b = TransformKind(propagate_b) - if with_bias: - raise NotImplementedError trans_A, trans_B = parse_layout(layout) @@ -163,3 +213,51 @@ def is_int4_dtype(dtype): ) else: raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}") + + +def select_scheduler( + arch: TileDevice, + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + 
zeros_mode="original", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + if is_ampere_arch(arch): + return ampere_select_scheduler( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + layout=layout, + zeros_mode=zeros_mode, + propagate_a=propagate_a, + propagate_b=propagate_b, + ) + elif is_volta_arch(arch): + raise NotImplementedError + else: + raise ValueError(f"Unsupported target: {arch.name}") diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index 036ace634..671cd256e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -17,7 +17,7 @@ _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, + _tir_packed_to_fp4_to_f16, _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) @@ -28,8 +28,7 @@ @dataclass -class MatmulDequantizeScheduler(BaseScheduler): - +class MatmulDequantizeBaseScheduler(BaseScheduler): # OP Related Config M: Optional[int] = None N: Optional[int] = None @@ -49,7 +48,60 @@ class MatmulDequantizeScheduler(BaseScheduler): group_size: int = -1 fast_decoding: bool = False with_bias: bool = False - zeros_mode: Literal["original", "rescale", "quantized"] = ("original",) + zeros_mode: Literal["original", "rescale", "quantized"] = "original" + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # 
Simple TIR Compute Expression + ir_module = matmul_dequantize_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + bit=self.num_bits, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + zeros_mode=self.zeros_mode, + ) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + return self.serialze_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + return self.get_roller_configs(arch, topk) + + # Check whether the result must be staged through a shared-memory cache before the final store + def check_require_cache(self) -> bool: + with_bias = self.with_bias + + conditions: List[bool] = [] + conditions.append(False) + # Bias Add should be done in shared memory + conditions.append(with_bias) + return any(conditions) # True only when with_bias is set (the sole condition so far) + + +@dataclass +class MatmulDequantizeScheduler(MatmulDequantizeBaseScheduler): # Default Tile Related Params block_M: int = 128 @@ -112,51 +164,12 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - # Simple TIR Compute Expression - ir_module = matmul_dequantize_select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - bit=self.num_bits, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, -
group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - zeros_mode=self.zeros_mode, - ) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - def serialze_hints_to_configs(hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - return serialze_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - return self.get_roller_configs(arch, topk) + def serialze_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs def with_default_config(self): block_M = getattr(self, "block_M", 64) @@ -202,6 +215,7 @@ def apply_config( self.accum_dtype, ) fast_decoding = self.fast_decoding + with_bias = self.with_bias num_bits = self.num_bits storage_dtype = self.storage_dtype @@ -223,6 +237,7 @@ def apply_config( Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + C_shape = (M, N) Bias_shape = (N,) A_shared_shape = (block_M, block_K) @@ -246,6 +261,8 @@ def apply_config( assert func_name is not None, "lop3_intrin_info is not found" import_source = self.common_header + import_source + cache_write_required = self.check_require_cache() + @T.prim_func def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), @@ -254,8 +271,8 @@ def general_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, 
block_M), threads=threads) as (bx, by): @@ -264,6 +281,7 @@ def general_dequant_matmul( B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_shared = T.alloc_shared([block_M, block_N], out_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) tx = T.thread_binding(0, threads, thread="threadIdx.x") @@ -325,7 +343,18 @@ def general_dequant_matmul( T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - T.copy(C_local, C[by * block_M, bx * block_N]) + if cache_write_required: + T.copy(C_local, C_shared) + if with_bias: + for i, j in T.grid(block_M, block_N): + C_shared[i, j] += Bias[bx * block_N + j] + + T.copy(C_shared, C[by * block_M, bx * block_N]) + else: + if with_bias: + for i, j in T.grid(block_M, block_N): + C_local[i, j] += Bias[bx * block_N + j] + T.copy(C_local, C[by * block_M, bx * block_N]) return self.maybe_simplify(general_dequant_matmul) @@ -364,7 +393,7 @@ def naive_cast_dequant(x): else: dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) elif source_format == "fp": - dequant_func = _tir_u32_to_f4_to_f16 + dequant_func = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit) elif source_format == "fp_e4m3": dequant_func = _tir_u8_to_f8_e4m3_to_f16 else: @@ -408,7 +437,7 @@ def _normal_dequant_impl( qzeros_buffer: T.Buffer, ): for v in T.serial(0, local_size): - index = (i * threads * local_size + tx * local_size + v) + index = i * threads * local_size + tx * local_size + v vi = index // stride_k vj = index % stride_k if not with_scaling: @@ -531,8 +560,10 @@ def _normal_fast_dequant_impl( T.address_of(dequant_weight_local[0]), T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), T.address_of(zeros_buffer[pid_n * stride_n, k * stride_k // group_size]), - T.address_of(qzeros_buffer[k * stride_k // group_size, - pid_n * stride_n // 
num_elems_per_byte]), + T.address_of(qzeros_buffer[ + k * stride_k // group_size, + pid_n * stride_n // num_elems_per_byte, + ]), dtype=in_dtype, ) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index fff815e7d..00f90bd27 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T -from typing import Optional, List, Literal +from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401 @@ -12,21 +12,18 @@ from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitter, # noqa: F401 ) -from bitblas.ops.common import TransformKind # noqa: F401 -from bitblas.ops.base_scheduler import BaseScheduler -from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint from bitblas.base.roller.rasterization import NoRasterization -from bitblas.base.utils import get_roller_hints_from_func from dataclasses import dataclass -from bitblas.ops.general_matmul.tirscript import ( - matmul_dequantize_select_implementation,) +from bitblas.ops.general_matmul.tilelang.dequantize.block_primitive_tensorcore import ( + MatmulDequantizeBaseScheduler, # noqa: F401 +) from bitblas.tl.base_hint import BaseTLHint from bitblas.quantization import ( _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, + _tir_packed_to_fp4_to_f16, _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) @@ -37,28 +34,7 @@ @dataclass -class MatmulDequantizeFineGrainedScheduler(BaseScheduler): - - # OP Related Config - M: Optional[int] = None - N: Optional[int] = 
None - K: Optional[int] = None - trans_A: bool = False - trans_B: bool = False - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - - # Dequantize Config - num_bits: int = 4 - storage_dtype: str = "int8" - source_format: str = "uint" - with_scaling: bool = False - with_zeros: bool = False - group_size: int = -1 - fast_decoding: bool = False - with_bias: bool = False - zeros_mode: Literal["original", "rescale", "quantized"] = "original", +class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): # Tensor Core Warp Configuration block_row_warps: int = 2 @@ -131,50 +107,12 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - # Simple TIR Compute Expression - ir_module = matmul_dequantize_select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - bit=self.num_bits, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - zeros_mode=self.zeros_mode) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - def serialze_hints_to_configs(hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - return serialze_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - return self.get_roller_configs(arch, topk) + def serialze_hints_to_configs(self, hints: List[Hint]): 
+ configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs def with_default_config(self): block_row_warps = getattr(self, "block_row_warps", 2) @@ -241,6 +179,7 @@ def apply_config( warp_cols = warp_col_tiles // micro_size_y fast_decoding = self.fast_decoding + with_bias = self.with_bias num_bits = self.num_bits storage_dtype = self.storage_dtype @@ -258,6 +197,7 @@ def apply_config( A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) + C_shape = (M, N) LUT_shape = (group_size, K // num_elems_per_byte) Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) @@ -305,6 +245,8 @@ def apply_config( chunk=chunk, ) + cache_write_required = self.check_require_cache() + @T.prim_func def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), @@ -313,8 +255,8 @@ def general_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): @@ -418,22 +360,40 @@ def general_dequant_matmul( # Matrix multiplication on fragments mma_emitter.mma(A_frag, B_frag, C_frag) + if cache_write_required: + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) - # Store the result back to C shared memory - mma_emitter.stmatrix( - C_frag, - C_shared, - thread_bindings=tx, - ) - - # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] += Bias[bx * 
block_N + j] + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + # Store the result back to C global memory + mma_emitter.stmatrix( + C_frag, + C, + thread_bindings=tx, + pid_m=by, + pid_n=bx, + ) return self.maybe_simplify(general_dequant_matmul) @@ -472,7 +432,7 @@ def naive_cast_dequant(x): else: dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) elif source_format == "fp": - dequant_func = _tir_u32_to_f4_to_f16 + dequant_func = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit) elif source_format == "fp_e4m3": dequant_func = _tir_u8_to_f8_e4m3_to_f16 else: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py index afd8849fc..4fea63d10 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py @@ -78,6 +78,7 @@ def apply_config( warp_cols = warp_col_tiles // micro_size_y fast_decoding = self.fast_decoding + with_bias = self.with_bias num_bits = self.num_bits storage_dtype = self.storage_dtype @@ -103,6 +104,7 @@ def apply_config( Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + C_shape = (M, N) Bias_shape = (N,) A_shared_shape = (block_M, block_K) @@ -158,6 +160,8 @@ def apply_config( if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: vec_load_qb = block_N * block_K // num_elems_per_byte // threads + cache_write_required = self.check_require_cache() + @T.prim_func def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), @@ -166,8 +170,8 @@ def general_dequant_matmul( Scale: 
T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): @@ -273,21 +277,39 @@ def general_dequant_matmul( # Matrix multiplication on fragments mma_emitter.mma(A_frag, B_dequantize_frag, C_frag) - # Store the result back to C shared memory - mma_emitter.stmatrix( - C_frag, - C_shared, - thread_bindings=tx, - ) + if cache_write_required: + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) - # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] += Bias[bx * block_N + j] + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + mma_emitter.stmatrix( + C_frag, + C, + thread_bindings=tx, + pid_m=by, + pid_n=bx, + ) return self.maybe_simplify(general_dequant_matmul) diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py index 0cd17feb3..a9fb00864 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -10,7 +10,7 @@ _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, + _tir_packed_to_fp4_to_f16,
_tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) @@ -228,7 +228,7 @@ def decode(n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) @@ -417,7 +417,7 @@ def decode_func(n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) @@ -601,7 +601,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, @@ -803,7 +803,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, diff --git a/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/bitblas/ops/impl/batch_matmul_dequantize_impl.py index 6a5f740a0..dd0ad43d7 100644 --- a/bitblas/ops/impl/batch_matmul_dequantize_impl.py +++ b/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -7,7 +7,7 @@ from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, 
_tir_u32_to_f4_to_f16, + _tir_packed_to_unsigned_convert, _tir_packed_to_fp4_to_f16, _tir_u8_to_f8_e4m3_to_f16) @@ -64,7 +64,7 @@ def decode_func(b, n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[b, n, k], dtype=in_dtype) @@ -238,7 +238,7 @@ def decode_func(b, n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[b, n, k // n_float_per_elem], k % n_float_per_elem, diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index ec450610a..1bb3f519d 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -10,7 +10,7 @@ _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, + _tir_packed_to_fp4_to_f16, _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) @@ -228,7 +228,7 @@ def decode(n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) @@ -417,7 +417,7 @@ def decode_func(n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp": - w = 
_tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) @@ -598,7 +598,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, @@ -795,7 +795,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index bb63b10e5..aed833022 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -7,7 +7,7 @@ from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_packed_to_unsigned_convert, _tir_packed_to_fp4_to_f16, _tir_u8_to_f8_e4m3_to_f16) from typing import Union @@ -65,7 +65,7 @@ def decode_func(n, k): w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp_e4m3": w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) @@ -240,7 +240,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = 
_tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, @@ -449,7 +449,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( + w = _tir_packed_to_fp4_to_f16(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index ba3927005..3f1ed85ce 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -105,16 +105,18 @@ def __init__( self.target = target self.backend = backend - self.ir_module: Optional[IRModule] = ( - self._select_implementation() if self.is_tir_backend() else None) - self.scheduler: Optional[BaseScheduler] = ( - self._select_scheduler() if self.is_tilelang_backend() else None) - self.scheduled_ir_module: Optional[IRModule] = None self.rt_mod: Optional[Module] = None self.time_evaluator: Optional[Callable] = None self.dynamic_range: Optional[Dict] = None self.arch: Optional[TileDevice] = get_arch(target) if target else None + + # selector must be invoked after arch is initialized + self.ir_module: Optional[IRModule] = ( + self._select_implementation() if self.is_tir_backend() else None) + self.scheduler: Optional[BaseScheduler] = ( + self._select_scheduler() if self.is_tilelang_backend() else None) + self.pass_context: Optional[Dict] = None self.kernel_name_generator: Optional[BaseKernelNameGenerator] = ( @@ -369,6 +371,8 @@ def apply_fast_tuning( func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=parallel_build) # Return the best Config as Hint return (best.sch.mod, best.config) if best is not None else (None, None) + else: + raise ValueError(f"Unsupported backend: {self.backend}") def apply_fast_tuning_with_dynamic_range( self, @@ -376,17 +380,30 @@ def apply_fast_tuning_with_dynamic_range( target: Target, topk: int = 20, dynamic_range: Dict[str, List[int]] = None, + parallel_build=True, ): - scheduled_ir_module = 
fast_tune_with_dynamic_range( - func_or_scheduler, - target, - topk=topk, - parallel_build=True, - dynamic_range=dynamic_range, - kernel_name_generator=self.kernel_name_generator, - ) - if scheduled_ir_module is not None: - return scheduled_ir_module + if self.is_tir_backend(): + scheduled_ir_module = fast_tune_with_dynamic_range( + func_or_scheduler, + target, + topk=topk, + parallel_build=parallel_build, + dynamic_range=dynamic_range, + kernel_name_generator=self.kernel_name_generator, + ) + if scheduled_ir_module is not None: + return scheduled_ir_module + elif self.is_tilelang_backend(): + # Finetune the schedule + tuning_configs = self.get_tl_tuning_config(topk=topk) + assert len(tuning_configs) > 0, "No tuning config found for this operator." + _, best = tl_apply_and_build( + func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=parallel_build) + # Return the best Config as Hint + return (best.sch.mod, best.config) if best is not None else (None, None) + else: + raise ValueError(f"Unsupported backend: {self.backend}") + return None def hardware_aware_finetune( @@ -404,7 +421,9 @@ def hardware_aware_finetune( self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( func, target, topk, dynamic_range) elif self.is_tilelang_backend(): - raise NotImplementedError("Not support dynamic range for tilelang backend") + func = self.scheduler.with_default_config() + self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( + func, target, topk, dynamic_range) else: func_or_scheduler = (self.prim_func if self.is_tir_backend() else self.scheduler) scheduled_mod, best_hint = self.apply_fast_tuning( diff --git a/bitblas/quantization/__init__.py b/bitblas/quantization/__init__.py index 48059c8bd..5760695be 100644 --- a/bitblas/quantization/__init__.py +++ b/bitblas/quantization/__init__.py @@ -4,7 +4,7 @@ _tir_packed_int_to_int_convert, # noqa: F401 _tir_packed_to_signed_convert, # noqa: F401 _tir_packed_to_unsigned_convert, # noqa: F401 - 
_tir_u32_to_f4_to_f16, # noqa: F401 + _tir_packed_to_fp4_to_f16, # noqa: F401 _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 ) diff --git a/bitblas/quantization/quantization.py b/bitblas/quantization/quantization.py index f6fc75b4e..0b98a23ba 100644 --- a/bitblas/quantization/quantization.py +++ b/bitblas/quantization/quantization.py @@ -123,21 +123,37 @@ def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) -def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): +def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 assert dtype == "float16" assert val.dtype == "uint32" # e_f4 == 0 -> e_f16 = 0 # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f16 = e_f4 | tir.const(8, "uint32") + mask = tvm.tir.const((1 << nbit) - 1, "uint16") + f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask + s = f4 >> tir.const(3, "uint16") + e_f4 = f4 & tir.const(7, "uint16") + e_f16 = e_f4 | tir.const(8, "uint16") val_f16 = tir.reinterpret("float16", - (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) + return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16) + +def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, 
f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + f4 = ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(storage_dtype) + f4 = (val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask + s = f4 >> tir.const(3, storage_dtype) + e_f4 = f4 & tir.const(7, storage_dtype) + e_f16 = e_f4 | tir.const(8, storage_dtype) + val_f16 = tir.reinterpret("float16", + ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16")) + return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16) + + return f_convert def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 diff --git a/bitblas/tl/mma_macro_generator.py b/bitblas/tl/mma_macro_generator.py index fd8ec43ae..edad06f75 100644 --- a/bitblas/tl/mma_macro_generator.py +++ b/bitblas/tl/mma_macro_generator.py @@ -2,10 +2,10 @@ # Licensed under the MIT License. import tvm.tl.language as T - -from typing import Union +from typing import Union, Tuple, Optional from bitblas.ops.common import TransformKind from tvm import DataType +from tvm.tir import PrimExpr from tvm.runtime import convert from .utils import ( mma_store_index_map, @@ -33,20 +33,24 @@ class TensorCoreIntrinEmitter(object): "e5m2_float8": "e5m2", } + # Represent the thread binding in the form of (tx, warp_n, warp_m) + is_m_first = False + def __init__( self, - a_dtype="float16", - b_dtype="float16", - accum_dtype="float16", - a_transposed=False, - b_transposed=False, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=8, - warp_col_tiles=8, - chunk=16, - reduce_k=1, - num_elems_per_byte=1, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + 
num_elems_per_byte: int = 1, + is_m_first: Optional[bool] = False, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -64,10 +68,12 @@ def __init__( self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) self._initialize_mma_prefix(self.k_dim) self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) + self._initialize_is_m_first(is_m_first) + self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k - self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) + self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.num_elems_per_byte = num_elems_per_byte def _initialize_k_dim(self, a_dtype="float16"): @@ -98,17 +104,50 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): self.micro_size_y = n_dim self.micro_size_k = k_dim - def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): + def _initialize_is_m_first(self, is_m_first: Optional[bool] = False): + if is_m_first is not None: + self.is_m_first = is_m_first + + def extract_thread_binding(self, + thread_id, + is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: + """ + is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) + which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] + Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)] + """ WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = ( + thread_id % WARP_SIZE, + (thread_id // WARP_SIZE) % block_col_warps, + (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps, + ) + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = ( + thread_id % 
WARP_SIZE, + (thread_id // WARP_SIZE) % block_row_warps, + (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps, + ) + return lane_id, warp_n, warp_m + + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows chunk = self.chunk micro_size_x = self.micro_size_x micro_size_k = self.micro_size_k + local_size_a = self.local_size_a a_dtype = self.a_dtype a_transposed = self.a_transposed - local_size_a = self.local_size_a @T.macro def _warp_ldmatrix_a( @@ -119,9 +158,7 @@ def _warp_ldmatrix_a( rk=0, ): stride = A_shared_buf.shape[-1] - tx = thread_bindings % WARP_SIZE - ty = (thread_bindings // WARP_SIZE) % block_row_warps - + tx, _, warp_m = self.extract_thread_binding(thread_bindings) for i in T.serial(warp_rows): T.ptx_ldmatrix( a_dtype, @@ -131,7 +168,7 @@ def _warp_ldmatrix_a( A_local_buf.data, i * local_size_a, T.address_of(A_shared_buf[ - ty * warp_row_tiles + i * micro_size_x, + warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k, ]), get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), @@ -140,10 +177,6 @@ def _warp_ldmatrix_a( return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): - - WARP_SIZE = self.WARP_SIZE - block_row_warps = self.block_row_warps - block_col_warps = self.block_col_warps warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -162,13 +195,12 @@ def _warp_ldmatrix_b( rk=0, ): stride = B_shared_buf.shape[-1] - tx = thread_bindings % WARP_SIZE - tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps + tx, warp_n, _ = self.extract_thread_binding(thread_bindings) for j in T.serial(warp_cols): # Assign B_shared_elem ri, rj = ( - tz * warp_col_tiles + j * micro_size_y, + warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) B_shared_elem = 
B_shared_buf[ri, rj] @@ -237,33 +269,50 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): return _warp_mma(A_local_buf, B_local_buf, C_local_buf) - def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings): - WARP_SIZE = self.WARP_SIZE + def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None): block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_out = self.local_size_out + is_global = pid_m is not None and pid_n is not None + BLOCK_M = block_row_warps * warp_rows + BLOCK_N = block_col_warps * warp_cols + M_DIM, N_DIM = self.M_DIM, self.N_DIM + # STS # MMA Store must be in simulated instead of TVM Intrins # As TVM Intrins is like a hack that the threadIdx.x should be always # equal to the warp_size @T.macro - def _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings): - tx = thread_bindings % WARP_SIZE - ty = (thread_bindings // WARP_SIZE) % block_row_warps - tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps + def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings): + tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings) + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, + col] = C_local_buf[i * (warp_cols * local_size_out) + + j * local_size_out + local_id] + + @T.macro + def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings): + tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings) for i, j in T.grid(warp_rows, warp_cols): for local_id_o in T.serial(local_size_out // 2): for local_id_i in T.vectorized(2): local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) - C_shared_buf[ty * warp_rows + i, tz * 
warp_cols + j, row, - col] = C_local_buf[i * (warp_cols * local_size_out) + - j * local_size_out + local_id] + C_buf[ + (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col, + ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + + local_id] - return _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings) + return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings) + if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings)) class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): @@ -274,20 +323,21 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): def __init__( self, - a_dtype="float16", - b_dtype="float16", - accum_dtype="float16", - a_transposed=False, - b_transposed=False, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=8, - warp_col_tiles=8, - chunk=16, - reduce_k=1, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: Optional[bool] = False, transform_kind_a: Union[int, TransformKind] = 0, transform_kind_b: Union[int, TransformKind] = 0, - num_elems_per_byte=1, ): super().__init__( a_dtype=a_dtype, @@ -302,6 +352,7 @@ def __init__( chunk=chunk, reduce_k=reduce_k, num_elems_per_byte=num_elems_per_byte, + is_m_first=is_m_first, ) self._initialize_transform_kind(transform_kind_a, transform_kind_b) @@ -352,9 +403,6 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): assert transform_kind_b in [0, 3], "Currently only support 0 and 3" def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): - WARP_SIZE = self.WARP_SIZE - block_row_warps = self.block_row_warps - 
block_col_warps = self.block_col_warps warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -375,14 +423,13 @@ def _warp_ldmatrix_b( rk=0, ): stride = B_shared_buf.shape[-1] - tx = thread_bindings % WARP_SIZE - tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps + tx, warp_n, _ = self.extract_thread_binding(thread_bindings) if transform_kind_b < TransformKind.LDMatrixTransform: for j in T.serial(warp_cols): # Assign B_shared_elem ri, rj = ( - tz * warp_col_tiles + j * micro_size_y, + warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) ni, nj, nii, njj = ( @@ -391,7 +438,7 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - args = ((ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj)) + args = (ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj) B_shared_elem = B_shared_buf[args] T.ptx_ldmatrix( @@ -410,7 +457,7 @@ def _warp_ldmatrix_b( for local_id in T.vectorized(local_size_dequantize): # Assign B_shared_elem ri, rj = ( - tz * warp_cols + j, + warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) rii, rjj = (tx * local_size_dequantize + @@ -491,16 +538,16 @@ def mma(self, A_local_buf, B_local_buf, C_local_buf): @T.macro def _warp_mma(A_local_buf, B_local_buf, C_local_buf): for i, j in T.grid(warp_rows, warp_cols): - ''' - A[16, 32], B[16, 32], C[16, 16] - A_local_size -> 16 - B_local_size -> 16 - C_local_size -> 8 - For each m16n8k32 inst - For A: m16k32 consume 16 int4 elements -> 8 A_local_size - For A: n8k32 consume 8 int4 elements -> 4 B_local_size - For C: m16n8 consume 4 int32 elements -> 4 C_local_size - ''' + """ + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + """ # A[0:16, 0:16] * B[0:8, 0:16] -> 
C[0:16, 0:8] T.ptx_mma( @@ -595,16 +642,16 @@ def mma(self, A_local_buf, B_local_buf, C_local_buf): @T.macro def _warp_mma(A_local_buf, B_local_buf, C_local_buf): for i, j in T.grid(warp_rows, warp_cols): - ''' - A[16, 32], B[16, 32], C[16, 16] - A_local_size -> 16 - B_local_size -> 16 - C_local_size -> 8 - For each m16n8k32 inst - For A: m16k32 consume 16 int4 elements -> 8 A_local_size - For A: n8k32 consume 8 int4 elements -> 4 B_local_size - For C: m16n8 consume 4 int32 elements -> 4 C_local_size - ''' + """ + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + """ # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] T.ptx_mma( diff --git a/integration/BitNet/vllm_workspace/conftest.py b/integration/BitNet/vllm_workspace/conftest.py index c99f334cb..4ddc637e6 100644 --- a/integration/BitNet/vllm_workspace/conftest.py +++ b/integration/BitNet/vllm_workspace/conftest.py @@ -97,11 +97,9 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. 
""" - if request.node.get_closest_marker("skip_global_cleanup"): + if not request.node.get_closest_marker("skip_global_cleanup"): return False - return True - @pytest.fixture(autouse=True) def cleanup_fixture(should_do_global_cleanup_after_test: bool): diff --git a/testing/python/operators/test_general_matmul_ops_backend_tl.py b/testing/python/operators/test_general_matmul_ops_backend_tl.py index 3e9d55530..6ed36b427 100644 --- a/testing/python/operators/test_general_matmul_ops_backend_tl.py +++ b/testing/python/operators/test_general_matmul_ops_backend_tl.py @@ -230,20 +230,31 @@ def test_matmul_codegen_default(): False, False, None), matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), - # FP32 Accum - matmul_codegen_default(768, 768, 768, "float16", "float16", "float32", "float16", "nt", False, - -1, False, False, None), - # INT32 Accum + matmul_codegen_default(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, + False, None), matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, + False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, True, "original"), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, False, None), + 
matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, True, "original"), def test_matmul_finetune(): matmul_finetune(1024, 1024, 1024, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None, False) - matmul_finetune(1024, 1024, 1024, "float16", "float16", "float16", "float16", "nt", False, -1, - False, False, None, False) - def test_matmul_torch_forward(): matmul_torch_forward(1024, 1024, 1024, "float16", "float16", "float16", "float16", "nt", None, diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 857b22270..2669d6c2c 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -5,7 +5,7 @@ import bitblas.testing from tvm import tl from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulScheduler, + MatmulBlockScheduler, MatmulFineGrainScheduler, MatmulWeightPropagationScheduler, ) @@ -41,7 +41,7 @@ def assert_matmul_blocked_with_default_correctness( out_dtype="float16", accum_dtype="float16", ): - matmul = MatmulScheduler( + matmul = MatmulBlockScheduler( M=M, N=N, K=K, @@ -92,7 +92,7 @@ def assert_matmul_blocked_apply_config_correctness( threads=128, enable_rasterization=False, ): - matmul = MatmulScheduler( + matmul = MatmulBlockScheduler( M=M, N=N, K=K, @@ -1398,7 +1398,6 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( qw = qw.reshape(qw_shape) permuted_inputs.append(torch.from_numpy(qw).cuda()) if with_scaling: - # permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) zeros = None diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index 
adb0b057f..03767a4f3 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -5,7 +5,7 @@ import bitblas.testing from tvm.ir import structural_equal from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulScheduler,) + MatmulBlockScheduler,) from bitblas.ops.general_matmul.tilelang.dequantize import (MatmulDequantizeScheduler) @@ -17,7 +17,7 @@ def assert_dense_scheduler_simplify(M, in_dtype="float16", out_dtype="float16", accum_dtype="float16"): - matmul = MatmulScheduler( + matmul = MatmulBlockScheduler( M=M, N=N, K=K, @@ -28,7 +28,7 @@ def assert_dense_scheduler_simplify(M, accum_dtype=accum_dtype, ).deactivate_simplify().with_default_config() - simplified = MatmulScheduler.Simplify(matmul) + simplified = MatmulBlockScheduler.Simplify(matmul) is_equal = structural_equal(matmul, simplified) if is_equal: diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py new file mode 100644 index 000000000..a1ff6b098 --- /dev/null +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul_simt( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + # This is a debug config + block_size_x = 8 + block_size_y = 8 + thread_row_tiles = 16 + thread_col_tiles = 16 + chunk = 16 + + shared_scope = "shared" + + block_M = block_size_x * thread_row_tiles + block_N = block_size_y * thread_col_tiles + block_K = chunk + + # Pipeline Stage + + A_shape = (M, K) + B_shape = (N, K) + C_shape = (M, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + + threads = thread_row_tiles * thread_col_tiles + local_size_a = block_M // thread_row_tiles + local_size_b = block_N // thread_col_tiles + local_size_c = (block_M // thread_row_tiles) * (block_N // thread_col_tiles) + + micro_size_k = 128 // DataType(in_dtype).bits + dp4a_size = 4 + use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), 
threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + + A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype) + B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype) + C_local = T.alloc_local((local_size_c,), accum_dtype) + + thread_binding = T.thread_binding(threads, "threadIdx.x") + + warp_m = thread_binding % thread_row_tiles + warp_n = thread_binding // thread_row_tiles + + T.clear(C_local) + + for ko in T.serial(K // block_K): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial((block_K // micro_size_k)): + for i in T.serial(local_size_a): + for mk in T.vectorized(micro_size_k): + A_local[i, mk] = A_shared[warp_m * local_size_a + i, + ki * micro_size_k + mk] + + for i in T.serial(local_size_b): + for mk in T.vectorized(micro_size_k): + B_local[i, mk] = B_shared[warp_n * local_size_b + i, + ki * micro_size_k + mk] + + for i, j in T.grid(local_size_a, local_size_b): + for mk in T.serial(micro_size_k // dp4a_size): + if use_dp4a: + T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], + C_local[i * local_size_b + j]) + else: + for dp4a_idx in T.serial(dp4a_size): + C_local[i * local_size_b + + j] += A_local[i, mk * dp4a_size + + dp4a_idx] * B_local[j, mk * dp4a_size + + dp4a_idx] + + for i, j in T.grid(local_size_a, local_size_b): + C[by * block_M + warp_m * local_size_a + i, + bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = 
mod.imported_modules[0].get_source() + print(src_code) + # src_code is the generated cuda source + assert src_code is not None + + if in_dtype == "int8": + A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32") + + +if __name__ == "__main__": + bitblas.testing.main()