diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index c97f5c9ec..d4ffbafc9 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -693,6 +693,10 @@ def sch_dequantize_in_register_with_config( V compute """ + weight_transform_kind = config.intrin_info.weight_transform_kind + if weight_transform_kind == TransformKind.LDMatrixTransform: + return self.sch_warp_memory_prefetch_with_config(func, config) + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group,) from .intrin import get_lop3_intrin_group diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 227de7ad3..eea256fd9 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -383,14 +383,20 @@ def with_default_config(self): def apply_config( self, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=32, - warp_col_tiles=32, - chunk=16, - num_stages=2, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, enable_rasterization=False, ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B @@ -534,6 +540,9 @@ def __post_init__(self): @dataclass class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): + # Ladder Transform Config + weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + def apply_config( self, block_row_warps=2, @@ -604,7 +613,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, - transform_kind_b=TransformKind.LDMatrixTransform, + transform_kind_b=self.weight_transform_kind, ) # Define the main kernel using the generated configuration diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index bc13c9d4c..9fe99512c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -5,6 +5,14 @@ MatmulDequantizeScheduler, # noqa: F401 ) +from .finegrained_primitive_tensorcore import ( + MatmulDequantizeFineGrainedScheduler, # noqa: F401 +) + +from .ladder_weight_transform_tensorcore import ( + MatmulDequantizeWeightPropagationScheduler, # noqa: F401 +) + from bitblas.ops.common import TransformKind from typing import Union diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index fce026c51..5f1a8f5ed 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -312,7 +312,6 @@ def general_dequant_matmul( Zeros, Qzeros, local_size, - local_size_compressed, bx, tx, k, @@ -384,7 +383,6 @@ def _normal_dequant( zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, local_size: int, - local_size_compressed: int, pid_n: T.Var, tx: T.Var, k: T.Var, @@ -413,9 +411,9 @@ def _normal_dequant_impl( qzeros_buffer: T.Buffer, ): for v in T.serial(0, local_size): - index = (i * threads * local_size_compressed + tx * local_size_compressed + v) - vi = index // (stride_k // num_elems_per_byte) - vj = index % (stride_k // num_elems_per_byte) + index = (i * threads * local_size + tx * local_size + v) + vi = index // stride_k + vj = index % stride_k if not with_scaling: dequant_weight_local[v] = self._decode_func( num_bits, @@ -486,12 +484,9 @@ def _normal_fast_dequant( qzeros_buffer: T.Buffer, func_name: str, pid_n: T.Var, - tx: T.Var, k: T.Var, - i: T.Var, stride_n: int, stride_k: int, - threads: int, ): num_elems_per_byte = self.num_elems_per_byte with_scaling = self.with_scaling diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index c98474ec0..d755ba2f8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -11,7 +11,6 @@ from bitblas.tl.macro_generator import ( TensorCoreIntrinEmitter, # noqa: F401 - TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) from bitblas.ops.common import TransformKind # noqa: F401 from bitblas.ops.base_scheduler import BaseScheduler @@ -31,13 +30,14 @@ _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group # GPU warp configuration for NVIDIA GPUs warp_size = 32 @dataclass -class MatmulDequantizeScheduler(BaseScheduler): +class MatmulDequantizeFineGrainedScheduler(BaseScheduler): # OP Related Config M: Optional[int] = None @@ -60,12 +60,15 @@ class MatmulDequantizeScheduler(BaseScheduler): with_bias: bool = False zeros_mode: Literal["original", "rescale", "quantized"] = "original", - # Default Tile Related Params - block_M: int = 64 - block_N: int = 64 - block_K: int = 32 + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 64 + warp_col_tiles: int = 64 + chunk: int = 32 # Usually determines the K-dimension split size + + # Other Optimization Parameters num_stages: int = 2 - threads: int = 128 enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): @@ -88,36 +91,43 @@ def from_roller_hint(cls, hint: Hint): block_row_warps = block[0] // warp[0] block_col_warps = block[1] // warp[1] - warp_size = 32 # NVIDIA GPU warp size is 32 + warp_row_tiles = warp[0] + warp_col_tiles = warp[1] + chunk = rstep[0] + if num_stages == 1: num_stages = 0 # disable pipelining - tl_hint.block_M = block[0] - tl_hint.block_N = block[1] - tl_hint.block_K = rstep[0] + tl_hint.block_row_warps = block_row_warps + tl_hint.block_col_warps = block_col_warps + tl_hint.warp_row_tiles = warp_row_tiles + tl_hint.warp_col_tiles = warp_col_tiles + tl_hint.chunk = chunk tl_hint.num_stages = num_stages - tl_hint.threads = warp_size * block_row_warps * block_col_warps tl_hint.enable_rasterization = enable_rasterization return tl_hint def get_config_params(self): return { - "block_M": self.block_M, - "block_N": self.block_N, - "block_K": self.block_K, + "block_row_warps": self.block_row_warps, + "block_col_warps": self.block_col_warps, + "warp_row_tiles": self.warp_row_tiles, + "warp_col_tiles": self.warp_col_tiles, + "chunk": self.chunk, "num_stages": self.num_stages, - "threads": self.threads, "enable_rasterization": self.enable_rasterization, } def __repr__(self): return ("{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," + f"block_M={self.block_row_warps * self.warp_row_tiles}," + f"block_N={self.block_col_warps * self.warp_col_tiles}," + f"warp_M={self.warp_row_tiles}," + f"warp_N={self.warp_col_tiles}," + f"block_K={self.chunk}," + f"threads={self.block_row_warps * self.block_col_warps * warp_size}," f"num_stages={self.num_stages}," - f"threads={self.threads}," f"enable_rasterization={self.enable_rasterization}" "}") @@ -167,56 +177,71 @@ def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): return self.get_roller_configs(arch, topk) def with_default_config(self): - block_M = getattr(self, "block_M", 64) - block_N = getattr(self, "block_N", 64) - block_K = getattr(self, "block_K", 32) + block_row_warps = getattr(self, "block_row_warps", 2) + block_col_warps = getattr(self, "block_col_warps", 2) + warp_row_tiles = getattr(self, "warp_row_tiles", 32) + warp_col_tiles = getattr(self, "warp_col_tiles", 32) + chunk = getattr(self, "chunk", 32) num_stages = getattr(self, "num_stages", 2) - threads = getattr(self, "threads", 128) enable_rasterization = getattr(self, "enable_rasterization", False) return self.apply_config( - block_M=block_M, - block_N=block_N, - block_K=block_K, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, num_stages=num_stages, - threads=threads, enable_rasterization=enable_rasterization, ) - def _apply_config_dequant_only( + def apply_config( self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, + enable_rasterization=False, ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" + M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B + assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - # check is dequantize only + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - def check_is_dequantize_only(): - return not self.with_scaling + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) - if not check_is_dequantize_only(): - raise ValueError("Not a Dequantize Only Configuration") + fragement_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y - in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + fast_decoding = self.fast_decoding num_bits = self.num_bits storage_dtype = self.storage_dtype + source_format = self.source_format storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - num_elems_per_byte = 8 // num_bits + num_elems_per_byte = self.num_elems_per_byte MAX_TRANSACTION_SIZE_IN_BITS = 128 local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits @@ -227,50 +252,139 @@ def check_is_dequantize_only(): group_size = K A_shape = (M, K) - B_shape = (N, K // storage_nbit * num_bits) + B_shape = (N, K // num_elems_per_byte) + LUT_shape = (group_size, K // num_elems_per_byte) + Scale_shape = (N, K // group_size) + Zeros_shape = (N, K // group_size) + Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + Bias_shape = (N,) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) @T.prim_func - def main( + def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) tx = T.thread_binding(0, threads, thread="threadIdx.x") + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + }) + T.use_swizzle(10, enable=enable_rasterization) - T.clear(C_local) + T.import_source(import_source) + + T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared) for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = i * threads * local_size_compressed + tx * local_size_compressed + v + index = ( + i * threads * local_size_compressed + tx * local_size_compressed + + v) vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] - for v in T.serial(0, local_size): - B_dequantize_local[v] = self._decode_func( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, + + if fast_decoding is True: + self._normal_fast_dequant( + B_local, + B_dequantize_local, + Scale, + Zeros, + Qzeros, + func_name, + by, + tx, + ko, + i, + block_N, + block_K, + threads, + ) + else: + self._normal_dequant( + B_local, + B_dequantize_local, + Scale, + Zeros, + Qzeros, + local_size, + local_size_compressed, + bx, + tx, + ko, + i, + block_N, + block_K, + threads, ) for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v @@ -278,92 +392,45 @@ def main( vj = index % block_K B_dequantize_shared[vi, vj] = B_dequantize_local[v] - T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - def _apply_config_with_scaling( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling Configuration is not implemented") - - def _apply_config_with_scaling_zeros_original_or_rescale( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling and Zeros Original Configuration is not implemented") - - def _apply_config_with_scaling_zeros_quantized( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling and Zeros Rescale Configuration is not implemented") - - def apply_config( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" - assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" - trans_A, trans_B = self.trans_A, self.trans_B - - assert trans_A is False, "Dequantize only implement for trans_A=False currently" - assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - - with_scaling = self.with_scaling - with_zeros = self.with_zeros - zeros_mode = self.zeros_mode - - args = [block_M, block_N, block_K, num_stages, threads, enable_rasterization] - - dequant_prim_func = None - if not with_scaling: - dequant_prim_func = self._apply_config_dequant_only(*args) - - if not with_zeros: - dequant_prim_func = self._apply_config_with_scaling(*args) - - if zeros_mode in ["original", "rescale"]: - dequant_prim_func = self._apply_config_with_scaling_zeros_original_or_rescale(*args) - elif zeros_mode == "quantized": - dequant_prim_func = self._apply_config_with_scaling_zeros_quantized(*args) - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - if dequant_prim_func is None: - raise ValueError("Unsupported Configuration") - - return self.maybe_simplify(dequant_prim_func) + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=tx, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) @property def _decode_func(self): @@ -375,7 +442,7 @@ def _decode_func(self): source_format = self.source_format storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - bit = self.bit + num_bits = self.num_bits dequant_func = None @@ -385,17 +452,17 @@ def naive_cast_dequant(x): if with_zeros and zeros_mode == "quantized": dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) elif source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed + if num_bits == 8: + # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) elif source_format == "int": - if bit == 1: + if num_bits == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) - elif bit == 8: - # 8 bit does not need to be compressed + elif num_bits == 8: + # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) @@ -408,6 +475,190 @@ def naive_cast_dequant(x): return dequant_func + def _normal_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + local_size: int, + local_size_compressed: int, + pid_n: T.Var, + tx: T.Var, + k: T.Var, + i: T.Var, + stride_n: int, + stride_k: int, + threads: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + num_bits = self.num_bits + in_dtype = self.in_dtype + group_size = self.group_size + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + + @T.macro + def _normal_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for v in T.serial(0, local_size): + index = (i * threads * local_size + tx * local_size + v) + vi = index // (stride_k) + vj = index % (stride_k) + if not with_scaling: + dequant_weight_local[v] = self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + elif not with_zeros: + # Scaling only + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size]) + elif zeros_mode == "original": + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // + group_size]) * scale_buffer[pid_n * stride_n + vi, + (k * stride_k + vj) // group_size] + elif zeros_mode == "rescale": + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size] - + zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size]) + elif zeros_mode == "quantized": + dequant_qzeros = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros, + dtype=in_dtype, + )) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size] + + return _normal_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def _normal_fast_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + func_name: str, + pid_n: T.Var, + tx: T.Var, + k: T.Var, + i: T.Var, + stride_n: int, + stride_k: int, + threads: int, + ): + # TODO(lei): un-used arguments should be removed + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + in_dtype = self.in_dtype + group_size = self.group_size + + @T.macro + def _normal_fast_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + if not with_scaling: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + dtype=in_dtype, + ) + elif not with_zeros: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + dtype=in_dtype, + ) + elif zeros_mode in ["original", "rescale"]: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(zeros_buffer[pid_n * stride_n, k * stride_k // group_size]), + dtype=in_dtype, + ) + elif zeros_mode == "quantized": + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(zeros_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(qzeros_buffer[k * stride_k // group_size, + pid_n * stride_n // num_elems_per_byte]), + dtype=in_dtype, + ) + + return _normal_fast_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + @property + def num_elems_per_byte(self): + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + num_bits = self.num_bits + return storage_nbit // num_bits + def __post_init__(self): - # Add Config Validation - return + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py index e69de29bb..bb463e59a 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py @@ -0,0 +1,540 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional +from bitblas.tl.utils import ( + get_mma_micro_size, # noqa: F401 + make_swizzle_layout, # noqa: F401 +) +from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 +) +from bitblas.ops.common import TransformKind # noqa: F401 +from dataclasses import dataclass +from bitblas.quantization import ( + _tir_packed_to_unsigned_convert,) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group +from bitblas.gpu.matmul_analysis import ( + get_propagate_map, + get_ladder_stage3_map, +) + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedScheduler): + + # Ladder Transform Config + weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + + def apply_config( + self, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, + enable_rasterization=False, + ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" + + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + weight_transform_kind = self.weight_transform_kind + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + assert (weight_transform_kind == TransformKind.LDMatrixTransform + ), "Dequantize only implement for LDMatrixTransform currently" + + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) + + fragement_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + fast_decoding = self.fast_decoding + + num_bits = self.num_bits + storage_dtype = self.storage_dtype + source_format = self.source_format + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = self.num_elems_per_byte + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = ( + N // micro_size_y, + K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + LUT_shape = (group_size, K // num_elems_per_byte) + Scale_shape = (N, K // group_size) + Zeros_shape = (N, K // group_size) + Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + Bias_shape = (N,) + + A_shared_shape = (block_M, block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + storage_scope="warp", # to get the ladder transform lop3 intrin + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter with ladder transform + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=weight_transform_kind, + num_elems_per_byte=num_elems_per_byte, + ) + + vec_load_qb = 16 + if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * block_K // num_elems_per_byte // threads + + @T.prim_func + def general_dequant_matmul( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size // num_elems_per_byte), + storage_dtype) + B_dequantize_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + }) + + T.use_swizzle(10, enable=enable_rasterization) + + T.import_source(import_source) + + T.clear(C_frag) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + idx = i * threads * vec_load_qb + tx * vec_load_qb + v + vkk = idx % (micro_size_k // num_elems_per_byte) + vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( + block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // + (block_K // micro_size_k)) % ( + block_N // micro_size_y) + B_shared[vj, vk, vjj, vkk] = B[ + bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, + vjj, + vkk, + ] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_shared, + ki, + thread_bindings=tx, + ) + + if fast_decoding is True: + self._normal_fast_dequant( + B_frag, + B_dequantize_frag, + Scale, + Zeros, + Qzeros, + func_name, + local_size, + warp_cols, + bx, + tx, + mma_emitter, + ko, + ki, + block_N, + block_K, + ) + else: + self._normal_dequant( + B_frag, + B_dequantize_frag, + Scale, + Zeros, + Qzeros, + local_size, + warp_cols, + bx, + tx, + mma_emitter, + ko, + ki, + block_N, + block_K, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_dequantize_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) + + def _normal_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + local_size: int, + warp_cols: int, + pid_n: T.Var, + thread_bindings: T.Var, + mma_emitter: TensorCoreIntrinEmitterWithLadderTransform, + ko: T.Var, + ki: T.Var, + stride_n: int, + stride_k: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + num_bits = self.num_bits + in_dtype = self.in_dtype + group_size = self.group_size + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + micro_size_k = mma_emitter.micro_size_k + k_inner_stride = micro_size_k // local_size + + @T.macro + def _normal_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for j in T.serial(warp_cols): + for v in T.serial(0, local_size): + tx = thread_bindings % mma_emitter.WARP_SIZE + tz = (thread_bindings // (mma_emitter.WARP_SIZE * mma_emitter.block_row_warps) + ) % mma_emitter.block_col_warps + vi = ( + tz * (warp_cols * mma_emitter.WARP_SIZE // k_inner_stride) + j * + (mma_emitter.WARP_SIZE // k_inner_stride) + (tx // k_inner_stride)) + vj = ki * micro_size_k + (tx % k_inner_stride) * local_size + v + remaped_i, remaped_j = self.get_param_indices( + pid_n * stride_n + vi, + ko * stride_k + vj, + transform_kind=TransformKind.LDMatrixTransform, + in_dtype=in_dtype, + matrix_name="B", + group_size=group_size, + ) + if not with_scaling: + dequant_weight_local[j * local_size + v] = self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + elif not with_zeros: + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j]) + elif zeros_mode == "original": + dequant_weight_local[j * local_size + v] = (self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_buffer[remaped_i, remaped_j]) * scale_buffer[remaped_i, remaped_j] + elif zeros_mode == "rescale": + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j] - + zeros_buffer[remaped_i, remaped_j]) + elif zeros_mode == "quantized": + dequant_qzeros = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + remaped_i, + remaped_j // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + + dequant_weight_local[j * local_size + v] = (self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros, + dtype=in_dtype, + )) * scale_buffer[remaped_i, remaped_j] + + return _normal_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def _normal_fast_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + func_name: str, + local_size: int, + warp_cols: int, + pid_n: T.Var, + thread_bindings: T.Var, + mma_emitter: TensorCoreIntrinEmitterWithLadderTransform, + ko: T.Var, + ki: T.Var, + stride_n: int, + stride_k: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + in_dtype = self.in_dtype + group_size = self.group_size + micro_size_k = mma_emitter.micro_size_k + k_inner_stride = micro_size_k // local_size + grouped_k = scale_buffer.shape[-1] + + @T.macro + def _normal_fast_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for j in T.serial(warp_cols): + tx = thread_bindings % mma_emitter.WARP_SIZE + tz = (thread_bindings // (mma_emitter.WARP_SIZE * mma_emitter.block_row_warps) + ) % mma_emitter.block_col_warps + vi = ( + tz * (warp_cols * mma_emitter.WARP_SIZE // k_inner_stride) + j * + (mma_emitter.WARP_SIZE // k_inner_stride) + (tx // k_inner_stride)) + vj = ki * micro_size_k + (tx % k_inner_stride) * local_size + remapped_i, remapped_j = self.get_param_indices( + pid_n * stride_n + vi, + ko * stride_k + vj, + transform_kind=TransformKind.LDMatrixTransform, + in_dtype=in_dtype, + matrix_name="B", + group_size=group_size, + ) + if not with_scaling: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + dtype=in_dtype, + ) + elif not with_zeros: + # Scaling only + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + T.address_of(scale_buffer[remapped_i, remapped_j]), + local_size * grouped_k, + local_size, + dtype=in_dtype, + ) + elif zeros_mode in ["original", "rescale"]: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + T.address_of(scale_buffer[remapped_i, remapped_j]), + T.address_of(zeros_buffer[remapped_i, remapped_j]), + local_size * grouped_k, + local_size, + dtype=in_dtype, + ) + # TODO: Implement quantized zeros + + return _normal_fast_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def get_param_indices( + self, + rl, + rr, + l=16, + r=16, + transform_kind=TransformKind.LDMatrixTransform, # noqa: E741 + trans=True, + in_dtype="float16", + matrix_name="B", + group_size=1, + ): # noqa: E741 + intra_index_map, _ = get_propagate_map(trans=trans, dtype=in_dtype, matrix_name=matrix_name) + + ladder_stage3_index_map, ladder_stage3_inverse_index_map = ( + get_ladder_stage3_map(dtype=in_dtype)) + + # assume the param layout is n, k + + warp_i, warp_j = rl % l, rr % r + + spatial_i, spatial_j = rl // l, rr // r + + # If is stage3 ladder transform + if transform_kind > 2: + warp_i, warp_j = ladder_stage3_inverse_index_map.map_indices([warp_i, warp_j]) + + warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + + return new_indices + + def __post_init__(self): + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 58f595984..1a83c0d18 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -370,7 +370,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( layout="nt", zeros_mode="original", ): - assert with_scaling, "Currently The test only support with scaling" + if group_size == -1: group_size = K propagate_b = 3 @@ -408,10 +408,10 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( matmul_func, config=bitblas.base.Hint.from_dict({ "arch": arch, - "block": [16, 128], - "warp": [16, 32], - "rstep": [128], - "pipeline_stage": 4, + "block": [128, 128], + "warp": [64, 64], + "rstep": [32], + "pipeline_stage": 2, "use_async": True, "intrin_info": intrin_info, "shared_scope": "shared.dyn", @@ -419,7 +419,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "b": 8, "a": 8 }, - "block_reduction_depth": 2, + "block_reduction_depth": 1, }), ) @@ -429,6 +429,8 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "tir.disable_cse_tir": True }): rt_mod = tvm.build(block_reduce_sch.mod, target=target) + src_code = rt_mod.imported_modules[0].get_source() + assert src_code is not None check_reduce(rt_mod) @@ -500,28 +502,38 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( transformed_b = transformed_b.cuda() c = c.cuda() scale = scale.cuda() - if zeros is not None: + args = [a, transformed_b] + if with_scaling: + args.append(scale) + if with_scaling and with_zeros: zeros = zeros.cuda() - torch_func(a, transformed_b, scale, zeros, c) - else: - torch_func(a, transformed_b, scale, c) - - rescale_b = torch.empty_like(b, dtype=torch.float16) - for i in range(N): - for j in range(K): - if with_zeros: - if zeros_mode == "original": - rescale_b[i, - j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // group_size] - elif zeros_mode == "rescale": - rescale_b[i, - j] = b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size] + args.append(zeros) + args.append(c) + + torch_func(*args) + + args = [a] + if with_scaling: + + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + if zeros_mode == "original": + rescale_b[i, + j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // group_size] + elif zeros_mode == "rescale": + rescale_b[i, + j] = b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size] + else: + raise NotImplementedError else: - raise NotImplementedError - else: - rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) - ref_c = torch.matmul(a, rescale_b.t().cuda()) + ref_c = torch.matmul(*args) print("rescale_b is \n", c) print("ref_c is \n", ref_c) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 280c170ac..31c3de7d1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -11,13 +11,18 @@ ) from bitblas.ops.general_matmul.tilelang.dequantize import ( - MatmulDequantizeScheduler,) + MatmulDequantizeScheduler, + MatmulDequantizeFineGrainedScheduler, + MatmulDequantizeWeightPropagationScheduler, +) import torch import torch.backends torch.manual_seed(0) +verbose = False + def assert_matmul_blocked_with_default_correctness( M, @@ -166,64 +171,6 @@ def assert_matmul_fine_grained_with_default_correctness( torch.matmul(A, B.T).to(getattr(torch, out_dtype)) if trans_B else torch.matmul(A, B).to( getattr(torch, out_dtype))) - # from bitblas.ops import Matmul, MatmulConfig - # matmul_config = MatmulConfig( - # M=M, - # N=N, - # K=K, - # propagate_a=False, - # propagate_b=False, - # ) - # matmul = Matmul(matmul_config, enable_tuning=False) - # prim_func = matmul.prim_func - # intrin_info = bitblas.base.hint.IntrinInfo( - # in_dtype=in_dtype, - # out_dtype=accum_dtype, - # trans_b=True, - # input_transform_kind=0, - # weight_transform_kind=0, - # ) - - # arch = bitblas.base.CUDA(target="cuda") - - # sch = bitblas.gpu.MatmulTensorizationMMA().apply_config( - # prim_func, - # config=bitblas.base.Hint.from_dict({ - # "arch": arch, - # "block": [64, 64], - # "warp": [32, 32], - # "rstep": [32], - # "pipeline_stage": 2, - # "use_async": True, - # "intrin_info": intrin_info, - # "shared_scope": "shared.dyn", - # "vectorize": { - # "b": 8, - # "a": 8 - # }, - # }), - # ) - - # with tvm.transform.PassContext(config={ - # "tir.use_async_copy": True, - # "tir.merge_static_smem": False - # }): - # rt_mod = tvm.build(sch.mod, target="cuda") - # from tvm.contrib.dlpack import to_pytorch_func - - # torch_func = to_pytorch_func(rt_mod) - - # matmul_c = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - # torch_func(A, B, matmul_c) - - # with open("debug/matmul_ref.cu", "w") as f: - # f.write(rt_mod.imported_modules[0].get_source()) - - # with open("debug/matmul_tl.cu", "w") as f: - # f.write(src_code) - - # torch.testing.assert_close(matmul_c, ref_c, rtol=1e0, atol=1e-1) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) @@ -439,6 +386,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( ): import numpy as np from bitblas.quantization import general_compress, interleave_weight + matmul = MatmulDequantizeScheduler( M=M, N=N, @@ -462,7 +410,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None - + print(src_code) input_shape = (M, K) weight_shape = (N, K) output_shape = (M, N) @@ -496,17 +444,17 @@ def assert_matmul_blocked_dequant_with_default_correctness( if with_scaling: if group_size == -1: group_size = K - permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) if with_zeros: if zeros_mode == "original": permuted_inputs.append( torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) elif zeros_mode == "rescale": - original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) scaled_zeros = original_zeros * permuted_inputs[-1] permuted_inputs.append(scaled_zeros) elif zeros_mode == "quantized": - original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros) qzeros = general_compress( original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) permuted_inputs.append(torch.from_numpy(qzeros).cuda()) @@ -521,7 +469,30 @@ def assert_matmul_blocked_dequant_with_default_correctness( print(permuted_inputs[-1]) - ref_result = torch.matmul(inputs[0], inputs[1].t().to(torch.float16)) + args = [inputs[0]] + b = inputs[1] + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + zeros = permuted_inputs[3] + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) print(ref_result) if zeros_mode == "rescale": @@ -530,6 +501,289 @@ def assert_matmul_blocked_dequant_with_default_correctness( torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) +def assert_matmul_fine_grained_dequant_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", +): + import numpy as np + from bitblas.quantization import general_compress, interleave_weight + + matmul = MatmulDequantizeFineGrainedScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + # lop3 transformation + if fast_decoding: + qw = interleave_weight(qw, bit, target_dtype=in_dtype) + permuted_inputs.append(torch.from_numpy(qw).cuda()) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) + elif zeros_mode == "rescale": + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros) + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + + permuted_inputs.append(inputs[2]) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(*permuted_inputs) + + print(permuted_inputs[-1]) + + args = [inputs[0]] + b = inputs[1] + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) + + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) + + +def assert_matmul_weight_transform_dequant_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", +): + import numpy as np + from bitblas.quantization import general_compress, interleave_weight + + matmul = MatmulDequantizeWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, + ).with_default_config() + if verbose: + print(matmul) + mod, params = tl.lower(matmul) + + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + if verbose: + print(src_code) + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) + if group_size == -1: + group_size = K + + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + storage_dtype=storage_dtype, + propagate_kind="B", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(intweight.cpu()).cuda().reshape(N, K) + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + qw = general_compress(LB.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + + # lop3 transformation + if fast_decoding: + qw = interleave_weight(qw, bit, target_dtype=in_dtype) + qw_shape = [int(v) for v in matmul.buffer_map[matmul.params[1]].shape] + qw = qw.reshape(qw_shape) + permuted_inputs.append(torch.from_numpy(qw).cuda()) + if with_scaling: + # permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) + + zeros = None + if with_zeros: + if zeros_mode == "original": + zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * maxq + elif zeros_mode == "rescale": + scale = permuted_inputs[2] + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * maxq) + zeros = -(original_zeros * scale.cuda()) + else: + raise NotImplementedError + + if with_scaling and with_zeros: + permuted_inputs.append(zeros) + + permuted_inputs.append(inputs[2]) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(*permuted_inputs) + + print(permuted_inputs[-1]) + + args = [inputs[0]] + b = inputs[1] + + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + zeros = permuted_inputs[3] + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e0) + + def test_matmul_blocked(): # Default assert_matmul_blocked_with_default_correctness(1024, 1024, 1024) @@ -569,11 +823,25 @@ def test_matmul_blocked_dequant_with_default(): assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) assert_matmul_blocked_dequant_with_default_correctness( - 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, with_zeros=True) + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + ) assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) assert_matmul_blocked_dequant_with_default_correctness( - 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, fast_decoding=True) + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, @@ -582,7 +850,79 @@ def test_matmul_blocked_dequant_with_default(): bit=4, with_scaling=True, with_zeros=True, - fast_decoding=True) + fast_decoding=True, + ) + + +def test_matmul_fine_grained_dequant_with_default(): + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=2) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + ) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + fast_decoding=True, + ) + + +def test_matmul_weight_transform_dequant_with_default(): + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=2) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, with_zeros=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + with_zeros=True, + ) if __name__ == "__main__": diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index bb3f38d24..620ef5be7 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -172,18 +172,18 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( micro_size_k = 32 # This is a debug config - block_row_warps = 1 - block_col_warps = 4 + block_row_warps = 2 + block_col_warps = 2 - warp_rows = 1 - warp_cols = 2 + warp_rows = 4 + warp_cols = 4 warp_row_tiles = micro_size_x * warp_rows warp_col_tiles = micro_size_y * warp_cols shared_scope = "shared.dyn" # Pipeline Stage stage = 2 - reduce_k = 2 + reduce_k = 1 block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles @@ -423,6 +423,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Get Reference Result ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("Ref C: ", ref_c) + print("C: ", C) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -437,5 +439,4 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): if __name__ == "__main__": - # bitblas.testing.main() - run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) + bitblas.testing.main()