diff --git a/format.sh b/format.sh index 915a3416f..c5e81a1ef 100755 --- a/format.sh +++ b/format.sh @@ -143,7 +143,7 @@ else # Check spelling only of the files that changed in last commit. spell_check_changed fi -echo 'BitBLAS codespell: Done' +echo 'bitblas codespell: Done' echo 'bitblas ruff: Check Start' # Lint specified files diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index a8b8c4986..4638f10bc 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -1,9 +1,14 @@ -import argparse from tvm import tl import tvm.tl.language as T from tvm.tl.autotuner import * from functools import partial import itertools +import torch +import bitblas +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) def get_configs(): @@ -22,13 +27,28 @@ def get_configs(): return configs -def ref_program(Q, K, V, casual): +def ref_program(Q, K, V, causal): from flash_attn.flash_attn_interface import flash_attn_func - return flash_attn_func(Q, K, V, causal=casual) + return flash_attn_func(Q, K, V, causal=causal) + + +def ref_flashattn_result(batch, heads, seq_len, dim, is_casual, dtype="float16"): + q_shape = (batch, seq_len, heads, dim) + k_shape = (batch, seq_len, heads, dim) + v_shape = (batch, seq_len, heads, dim) + typemap = {"float16": torch.float16} + Q = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(q_shape).type( + typemap[dtype]).cuda() + K = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(k_shape).type( + typemap[dtype]).cuda() + V = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(v_shape).type( + typemap[dtype]).cuda() + res = ref_program(Q, K, V, is_casual) + return res -def flashattn(batch, heads, seq_len, dim, is_casual): +def flashattn_autotune(batch, heads, seq_len, dim, is_causal): @autotune( configs=get_configs(), @@ -39,7 +59,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual): @jit( out_idx=[3], supply_type=tl.TensorSupplyType.Normal, - ref_prog=partial(ref_program, casual=is_casual), + ref_prog=partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01, ) @@ -81,10 +101,10 @@ def main( Q_local[i, j] *= scale loop_range = ( T.ceildiv( - (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)) + (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_casual: + if is_causal: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( bx * block_M + i >= k * block_N + j, @@ -128,23 +148,112 @@ def main( return kernel() +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_flashattn_autotune(): + flashattn_autotune(1, 4, 256, 256, True) + flashattn_autotune(1, 8, 256, 256, True) + flashattn_autotune(4, 4, 256, 256, True) + flashattn_autotune(4, 8, 256, 256, True) + + +def flashattn(batch, heads, seq_len, dim, is_causal): + + def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128): + scale = (1.0 / dim)**0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + print(type(seq_len), seq_len) + print(type(block_M), block_M) + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M, dim): + Q_local[i, j] *= scale + loop_range = ( + T.ceildiv( + (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + 0, + -T.infinity(acc_s.dtype), + ) + else: + T.clear(acc_s) + T.gemm( + Q_local, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(acc_s, acc_s_cast) + T.gemm( + acc_s_cast, + V_shared, + acc_o, + policy=T.GemmWarpPolicy.FullRow, + ) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return main + + mod, params = tl.lower(kernel()) + mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) + + +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_flashattn(): + flashattn(1, 4, 256, 256, True) + flashattn(1, 8, 256, 256, True) + flashattn(4, 4, 256, 256, True) + flashattn(4, 8, 256, 256, True) + + if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=64, help="Batch size") - parser.add_argument("--h", type=int, default=12, help="Number of heads") - parser.add_argument("--n_ctx", type=int, default=2048, help="Context size") - parser.add_argument("--d_head", type=int, default=256, help="Head dimension") - parser.add_argument("--casual", type=bool, default=True, help="Casual flag") - args = parser.parse_args() - BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head - casual = args.casual - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD - total_flops = 2 * flops_per_matmul - if casual: - total_flops *= 0.5 - - best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual) - print(f"Best latency: {best_latency}") - print(f"Best TFlops: {total_flops / best_latency * 1e-9}") - print(f"Best config: {best_config}") - print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") + bitblas.testing.main()