diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 2285a2313..5a57edbb2 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -12,6 +12,9 @@ Here is an example for a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplicatio import bitblas import torch +# enabling debug output + +bitblas.set_debug_level("Debug") matmul_config = bitblas.MatmulConfig( M=1, # M dimension N=1024, # N dimension @@ -125,6 +128,9 @@ Here is an example to define a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$: import bitblas import torch +# enabling debug output +bitblas.set_debug_level("Debug") + model = bitblas.Linear( in_features=1024, out_features=1024, @@ -178,6 +184,9 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( QuantLinear as CudaOldQuantLinear, ) +# enabling debug output +bitblas.set_debug_level("Debug") + in_features = 1024 out_features = 1024 group_size = 128 diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py index eaf15bc1d..e29c9de0f 100644 --- a/python/bitblas/module/__init__.py +++ b/python/bitblas/module/__init__.py @@ -232,15 +232,18 @@ def forward(self, A, output=None): A = A.half() # can be lifted to post init. self.init_params() - + if output is None: output = torch.empty( A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) A = self.bitblas_matmul.transform_input(A) + stream = torch.cuda.current_stream() + A_void = ctypes.c_void_p(A.data_ptr()) + stream_handle = ctypes.c_void_p(stream.cuda_stream) # m is the product of the last n - 1 dimensions of A - self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m) + self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle) return output diff --git a/python/bitblas/ops/general_matmul_splitk.py b/python/bitblas/ops/general_matmul_splitk.py index e951bf126..28e3cbbf2 100644 --- a/python/bitblas/ops/general_matmul_splitk.py +++ b/python/bitblas/ops/general_matmul_splitk.py @@ -160,7 +160,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if output is None: output = torch.empty( - (self.k_split,) + A.shape[:-1] + (self.N,), + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) if scale is not None: @@ -169,7 +169,12 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args.append(zeros) if bias is not None: args.append(bias) - args.append(output) + + sk_output = torch.empty((self.k_split,) + + A.shape[:-1] + (self.N,), + dtype=self.torch_output_dtype, + device=A.device) + args.append(sk_output) if self.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) @@ -180,7 +185,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.lib is None: self._forward_from_torch_func(*args) self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) - output = torch.sum(output, dim=0) + torch.sum(sk_output, dim=0, out=output) return output def __call__(self, *args: Any, **kwds: Any) -> Any: diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 5b7de9ab0..603a57248 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -171,4 +171,6 @@ def map_torch_type(intype): # fmt: on if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + test_matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None) diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index dd9b29d51..ac3a15a9c 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -41,20 +41,22 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) assert get_codegen_result(matmul) - @pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), + (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), + (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), ], ) -def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): - +def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, + layout, with_bias, group_size, with_scaling, with_zeros, + zeros_mode): + import torch + torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( + k_split=SplitK, M=M, N=N, K=K, @@ -70,20 +72,27 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo zeros_mode=zeros_mode, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) - matmul.hardware_aware_finetune(topk=10) - assert get_codegen_result(matmul) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) + + output_bitblas = matmul.forward(*inputs) + output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) @pytest.mark.parametrize( "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + (1, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, False, None), - (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + (4, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, False, None), ], ) -def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, +def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch @@ -103,18 +112,39 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu with_scaling=with_scaling, with_zeros=with_zeros, zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) - inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) + def map_torch_type(intype): - output_bitblas = matmul.forward(*inputs) - output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) - torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + typemap = { + 'e4m3_float8': torch.float8_e4m3fn, + 'e5m2_float8': torch.float8_e5m2, + } + if intype in typemap: + return typemap[intype] + else: + return getattr(torch, intype) + + numpytype_a = map_torch_type(A_dtype) + numpytype_b = map_torch_type(W_dtype) + + torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), + torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( + torch_a.to(torch.float32), torch_b.to(torch.float32)) + ref_out = ref_out.to(torch.float16) + bitblas_out = torch.empty_like(ref_out) + matmul.forward(torch_a, torch_b, output=bitblas_out) + print("torch_ref_out", ref_out) + print("bitblas_out", bitblas_out) + + torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1) # fmt: on