Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.

[BUG] Vectorized bias add combined with atomicAdd produces incorrect results #271

@LeiWang1999

Description

@LeiWang1999
  #pragma unroll
  for (int i_10 = 0; i_10 < 4; ++i_10) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + (((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 1) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_11 = 0; i_11 < 16; ++i_11) {
    atomicAddx2((&(C[(((((((int)blockIdx.y) * 65536) + (i_11 * 4096)) + ((((int)threadIdx.x) >> 5) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 31) * 2))])), (&(((half_t*)buf_dyn_shmem)[(((((((i_11 >> 2) * 1024) + (((((int)threadIdx.x) & 31) >> 3) * 256)) + ((i_11 & 3) * 64)) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 3072)])));
  }

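The kernel above adds the bias in shared memory and then writes the output with atomicAdd; it produces incorrect results. The equivalent kernel below writes the output with plain vectorized stores instead of atomicAdd, and it is correct: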

  for (int i_14 = 0; i_14 < 4; ++i_14) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + ((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) & 7) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_15 = 0; i_15 < 4; ++i_15) {
    *(uint4*)(C + (((((((int)blockIdx.y) * 65536) + (i_15 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_15 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
  }

Currently, we work around this by disabling atomicAdd whenever a bias is present.

To reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
import bitblas.testing
from bitblas import Linear as BitBLASLinear
import torch
import time
import numpy as np
import torch.nn as nn

# Seed PyTorch's RNG so the random layers, weights and inputs drawn below are
# reproducible from run to run, and enable DEBUG-level BitBLAS logging.
torch.manual_seed(0)
bitblas.set_log_level("DEBUG")


def correctness_consistent(m, in_features, out_features, bias):
    """Check that a float16 BitBLASLinear matches ``torch.nn.Linear``.

    A reference ``nn.Linear`` is created first. Its weight is copied into a
    ``BitBLASLinear`` via ``load_and_transform_weight``, and its bias too when
    ``bias`` is truthy. Both layers then run on the same random input.

    Args:
        m: Batch size. If it is not an ``int`` (for example a list of sizes
            forwarded as ``opt_M``), the integer mean of its entries is used
            as the actual batch size.
        in_features: Input feature dimension.
        out_features: Output feature dimension.
        bias: Whether both layers carry a bias term.

    The outputs are printed and then compared with
    ``bitblas.testing.torch_assert_close`` (rtol=1e-1, atol=1e-2), which
    fails the check on mismatch.
    """
    reference = nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()
    candidate = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype="float16",
        accum_dtype="float16",
        out_dtype="float16",
        opt_M=m,
    ).cuda()

    with torch.no_grad():
        candidate.load_and_transform_weight(reference.weight.clone())
        if bias:
            candidate.bias = nn.Parameter(reference.bias.clone())

    # opt_M may be a list of sizes; the check itself runs at their integer mean.
    batch = m if isinstance(m, int) else sum(m) // len(m)

    with torch.no_grad():
        x = torch.randn(batch, in_features, dtype=torch.float16).cuda()
        expected = reference(x)
        actual = candidate(x)

    print(expected)
    print(actual)
    bitblas.testing.torch_assert_close(expected, actual, rtol=1e-1, atol=1e-2)


def test_correctness_consistent():
    """Run ``correctness_consistent`` over a fixed set of configurations.

    Each tuple is ``(m, in_features, out_features, bias)``. The list-valued
    ``m`` exercises the multi-size ``opt_M`` path.
    """
    cases = (
        (1, 1024, 1024, False),
        (1, 1024, 1024, True),
        (1024, 1024, 1024, True),
        ([1, 1024], 1024, 1024, True),
    )
    for case in cases:
        correctness_consistent(*case)


def correctness_weight_only_dequantize(
    m,
    in_features,
    out_features,
    bias,
    W_dtype,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
):
    """Check a weight-only-quantized BitBLASLinear against a float16 matmul.

    Builds a BitBLASLinear with float16 activations, accumulation and output
    and a quantized weight of type ``W_dtype``. It is loaded with:
    - random integer weights,
    - all-ones scales (if ``with_scaling``),
    - constant zero points (if ``with_zeros``),
    - a random bias (if ``bias``).
    Its output is then compared with ``x @ w.T (+ bias)`` computed in PyTorch.

    Args:
        m: Batch size. If it is not an ``int`` (for example a list forwarded as
            ``opt_M``), the integer mean of its entries is used.
        in_features: Input feature (reduction) dimension.
        out_features: Output feature dimension.
        bias: Whether a bias is added.
        W_dtype: BitBLAS weight dtype string, e.g. ``"uint4"`` or ``"uint2"``.
        group_size: Quantization group size. ``-1`` means a single group
            spanning all of ``in_features``.
        with_scaling: Whether per-group scales are loaded (all ones here).
        with_zeros: Whether per-group zero points are loaded.
        zeros_mode: ``"original"``, ``"rescale"`` or ``"quantized"``. Only
            consulted when ``with_zeros`` is true.

    Raises:
        ValueError: ``zeros_mode == "rescale"`` without ``with_scaling``.
            Rescaled zeros are defined as ``zeros * scales``.
        NotImplementedError: Unsupported weight ``source_format`` or
            ``zeros_mode``.
        AssertionError: Output mismatch, raised by
            ``torch.testing.assert_close``.
    """
    # Fail fast, before building the operator. Previously this combination
    # multiplied the zeros by whatever tensor had been appended last (the
    # packed weight when scaling was off).
    if with_zeros and zeros_mode == "rescale" and not with_scaling:
        raise ValueError('zeros_mode="rescale" requires with_scaling=True')

    import numpy as np
    from bitblas.quantization.utils import general_compress
    from bitblas.cache import global_operator_cache

    # Clear the operator cache; presumably this forces this configuration to
    # be built fresh rather than reused from an earlier call.
    global_operator_cache.clear()
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype=W_dtype,
        accum_dtype="float16",
        out_dtype="float16",
        group_size=group_size,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        opt_M=m,
    ).cuda()
    if not isinstance(m, int):
        # opt_M may be a list of sizes; the check runs at their integer mean.
        m = sum(m) // len(m)
    input_shape = (m, in_features)
    weight_shape = (out_features, in_features)
    output_shape = (m, out_features)

    # -1 means one group spanning the whole reduction dimension. Resolve it
    # once so both the scales and the zero points below get a valid group
    # count. Previously it was only resolved when with_scaling was set, so
    # zeros without scales produced the negative dimension in_features // -1.
    # BitBLASLinear above still received the caller's original value.
    if group_size == -1:
        group_size = in_features
    num_groups = in_features // group_size

    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    source_format, bit = (
        linear_bitblas.bitblas_matmul.source_format,
        linear_bitblas.bitblas_matmul.bit,
    )

    # Weights are drawn from [0, maxq) for "uint" and from [-maxq, maxq) for
    # "int". The zero point is fixed at maxq. Note that for "uint" this
    # samples only the lower half of the 2**bit representable values.
    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    # Not read by the checks below. It does consume random numbers, so
    # removing it would change the bias drawn afterwards.
    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    # CPU copy fed to weight_transform. "int" sources are shifted by maxq
    # into the non-negative range first.
    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        # Reference weights are offset by the same zero point that is loaded
        # into linear_bitblas below.
        inputs[1] = inputs[1] - zeros
    bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()
    ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16))
    if bias:
        ref_result = ref_result + bias_tensor

    with torch.no_grad():
        permuted_inputs = []
        permuted_inputs.append(inputs[0])
        if linear_bitblas.bitblas_matmul.weight_transform is not None:
            permuted_inputs.append(
                linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda())
        else:
            permuted_inputs.append(inputs[1])
        linear_bitblas.qweight.data = permuted_inputs[-1].clone()

        scales = None
        if with_scaling:
            # All-ones scales keep the reference a plain matmul.
            scales = torch.ones([out_features, num_groups], dtype=torch.float16).cuda()
            permuted_inputs.append(scales)
            linear_bitblas.scales.data = scales.clone()
        if with_zeros:
            if zeros_mode == "original":
                permuted_inputs.append(
                    torch.ones([out_features, num_groups], dtype=torch.float16).cuda() * zeros)
            elif zeros_mode == "rescale":
                # scales is guaranteed non-None by the check at the top.
                original_zeros = (
                    torch.ones([out_features, num_groups], dtype=torch.float16).cuda() * zeros)
                scaled_zeros = original_zeros * scales
                permuted_inputs.append(scaled_zeros)
            elif zeros_mode == "quantized":
                original_zeros = (
                    torch.ones([num_groups, out_features], dtype=torch.int8).cuda() * zeros)
                qzeros = general_compress(
                    original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
                permuted_inputs.append(torch.from_numpy(qzeros).cuda())
            else:
                raise NotImplementedError
            linear_bitblas.zeros.data = permuted_inputs[-1].clone()
        if bias:
            permuted_inputs.append(bias_tensor)
            linear_bitblas.bias.data = bias_tensor.clone()

    with torch.no_grad():
        output_bitblas = linear_bitblas(inputs[0])

    # Loose tolerances, since accumulation is float16 (accum_dtype above).
    # The "original" zeros mode is checked with even looser bounds.
    rtol = 1e0
    atol = 1e0
    if zeros_mode == "original":
        rtol = 1e2
        atol = 1e2
    print(output_bitblas)
    print(ref_result)
    torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol)


def test_correctness_weight_only_dequantize():
    """Run ``correctness_weight_only_dequantize`` over fixed configurations.

    Each tuple is ``(m, in_features, out_features, bias, W_dtype,
    group_size, with_scaling, with_zeros, zeros_mode)``.
    """
    # NOTE(review): the first two configurations are identical. They are kept
    # as-is to preserve behavior; possibly one was meant to differ
    # (e.g. bias=True).
    configs = [
        (1, 1024, 1024, False, "uint4", -1, False, False, None),
        (1, 1024, 1024, False, "uint4", -1, False, False, None),
        (1024, 1024, 1024, True, "uint4", -1, False, False, None),
        (1, 1024, 1024, True, "uint2", -1, True, False, None),
        (1, 1024, 1024, True, "uint2", 128, True, True, "original"),
        (1024, 1024, 1024, True, "uint2", 128, True, True, "original"),
        (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"),
    ]
    for config in configs:
        correctness_weight_only_dequantize(*config)


def profile(model, input_data, *, warmup_seconds=1.0, target_ms=1000.0):
    """Return the mean per-call latency of ``model(input_data)`` in milliseconds.

    When CUDA is available, the model is moved with ``.cuda()`` and CUDA work
    is synchronized around every timing window. Otherwise both steps are
    skipped, so the helper also runs on CPU-only hosts. The model is switched
    to eval mode and warmed up for roughly ``warmup_seconds``. It is then
    timed over enough repeats to fill about ``target_ms`` of wall-clock time,
    estimated from a single warm call.

    Args:
        model: Object providing ``.eval()`` and ``__call__`` (and ``.cuda()``
            when CUDA is available), normally an ``nn.Module``.
        input_data: Argument passed to every ``model`` call.
        warmup_seconds: Minimum warm-up duration in seconds. The default is
            the previously hard-coded 1.0.
        target_ms: Approximate total timed duration in milliseconds. The
            default is the previously hard-coded 1000.

    Returns:
        Mean latency per call in milliseconds, as a Python ``float``.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
    model.eval()

    def sync():
        # Kernel launches are asynchronous; block until queued work finishes.
        if use_cuda:
            torch.cuda.synchronize()

    def get_runtime(num_repeats=1):
        # Drain previously queued work so it is not charged to this window.
        sync()
        tic = time.perf_counter()
        for _ in range(num_repeats):
            _ = model(input_data)
        sync()
        return (time.perf_counter() - tic) * 1000 / num_repeats

    with torch.no_grad():
        deadline = time.perf_counter() + warmup_seconds
        while time.perf_counter() < deadline:
            get_runtime()  # warmup
        warmup_runtime = get_runtime()
        # Guard against a zero reading, which the old time.time()-based
        # version could hit on coarse clocks (ZeroDivisionError).
        if warmup_runtime > 0:
            num_repeats = max(1, int(target_ms / warmup_runtime))
        else:
            num_repeats = 1
        mean_ms = get_runtime(num_repeats)
    return float(mean_ms)


if __name__ == "__main__":
    # Full test-suite entry point, disabled so that only the single case
    # below runs.
    # bitblas.testing.main()
    # Single configuration: m=1024, bias=True, uint4 weights, no scaling or
    # zeros. Presumably this is the bias case reported in this issue.
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions