diff --git a/3rdparty/tvm b/3rdparty/tvm index 08cce5bb7..2852f55c2 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 08cce5bb7443deedb7992a5f1c643f6e1838e8bf +Subproject commit 2852f55c268db21dc4e9d3e18aae65823c1157e6 diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py new file mode 100644 index 000000000..89a4aefbd --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .matmul import matmul_blocked # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py new file mode 100644 index 000000000..14efbae07 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +import tvm.tl.language as T + +from bitblas.tl.utils import ( + get_mma_micro_size, + make_swizzle_layout, +) + +from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter) + + +def maybe_pipeline( + iterable, + num_stages, +): + enable_pipeline = num_stages > 1 + if enable_pipeline: + return T.Pipelined(iterable, num_stages=num_stages) + else: + return T.serial(iterable) + + +def matmul_blocked( + M, + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=False, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization=False, # Enhance L2 Locality +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + if enable_rasterization: + # rasterization factor + T.use_swizzle(10) + + T.clear(C_local) + for k in maybe_pipeline(T.ceildiv(K, block_K), num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_macro_tensorcore( + M, + N, + K, + dtypeAB, + dtypeC, + accum_dtype, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + num_stages=2, + enable_rasterization=False, +): + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y) + + warp_size = 32 # nvidia gpu warp size is 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, shared_scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, shared_scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, shared_scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + if enable_rasterization: + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in maybe_pipeline(T.ceildiv(K, block_K), num_stages): + + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py new file mode 100644 index 000000000..01a729e9c --- /dev/null +++ b/bitblas/tl/mma_layout.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm import arith +from tvm import DataType +from typing import Union, Literal + + +def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): + row = thread_id % 16 + col = 8 * (thread_id // 16) + local_id % 8 + return row, col + + +def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (thread_id // 16) + (thread_id % 8) + col = 8 * ((thread_id % 16) // 8) + local_id % 8 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): + row = thread_id % 16 + col = local_id + (thread_id // 16) * 16 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + 16 * ((thread_id % 16) // 8) + return row, col + + +def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (local_id % 4 // 2) + (thread_id // 4) + col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) + return row, col + + +def shared_16x16_to_mma_32x8_smoothlayout(i, j): + return (i * 2 + j // 8, j % 8) + + +def shared_16x32_to_mma_32x16_smoothlayout(i, j): + return (i * 2 + j // 16, j % 16) + + +def shared_32x16_to_mma_32x16_smoothlayout(i, j): + return (i * 2 + j // 16, j % 16) diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index b41d7ff7d..4b8b4cf6e 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -1,8 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + from tvm import arith from tvm import DataType +import tvm.tl.language as T from typing import Union, Literal +from .mma_layout import ( + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, + mma_store_32x8_to_shared_16x16_layout, +) def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): @@ -61,48 +70,6 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner) -def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): - row = thread_id % 16 - col = 8 * (thread_id // 16) + local_id % 8 - return row, col - - -def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): - row = 8 * (thread_id // 16) + (thread_id % 8) - col = 8 * ((thread_id % 16) // 8) + local_id % 8 - return row, col - - -def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): - row = thread_id % 16 - col = local_id + (thread_id // 16) * 16 - return row, col - - -def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): - row = (thread_id // 16) * 8 + (thread_id % 8) - col = local_id + 16 * ((thread_id % 16) // 8) - return row, col - - -def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): - row = 8 * (local_id % 4 // 2) + (thread_id // 4) - col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) - return row, col - - -def shared_16x16_to_mma_32x8_smoothlayout(i, j): - return (i * 2 + j // 8, j % 8) - - -def shared_16x32_to_mma_32x16_smoothlayout(i, j): - return (i * 2 + j // 16, j % 16) - - -def shared_32x16_to_mma_32x16_smoothlayout(i, j): - return (i * 2 + j // 16, j % 16) - - def get_ldmatrix_offset( matrix: Literal["A", "B"], row_idx, @@ -129,3 +96,28 @@ def get_ldmatrix_offset( def mma_store_index_map(*args, **kwargs): return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs) + + +def get_mma_micro_size(dtype: Literal["float16", "int8"]): + # TODO(lei): FP8 related precision support. + # Basic Tensor Core Matrix Multiply operation Unit + micro_size_x = micro_size_y = 16 + micro_size_k = 16 + if dtype == "int8": + micro_size_k = 32 + return micro_size_x, micro_size_y, micro_size_k + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py new file mode 100644 index 000000000..314c85100 --- /dev/null +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm import tl +from bitblas.ops.general_matmul.tilelang.dense import matmul_blocked +import torch +import torch.backends + +torch.manual_seed(0) + + +def assert_tl_matmul_correctness(M, + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization=False): + matmul = matmul_blocked( + M, + N, + K, + block_M=block_M, + block_N=block_N, + block_K=block_K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_matmul_blocked(): + # Pipeline + assert_tl_matmul_correctness(1024, 1024, 1024, num_stages=2) + assert_tl_matmul_correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_tl_matmul_correctness(1024, 1024, 1024, enable_rasterization=True) + + +if __name__ == "__main__": + bitblas.testing.main()