This repository was archived by the owner on Feb 24, 2026. It is now read-only.

accuracy and performance of bfloat16 with bitblas linear #161

@AbedKhateeb2

Description


I tried to run a bfloat16 linear layer with BitBLAS,
but its output differs from the output of the torch linear layer.

output:
qunatizing /decoder/block/0/layer/0/SelfAttention/k
torch linear took by avg 7.802581787109375e-05
BitBLAS Operator found in global_operator_cache.
bitblas linear took to init : 1.157283067703247 sec
bitblas linear took by avg 7.474946975708008e-05
torch compare : tensor(2.2344, device='cuda:0', dtype=torch.bfloat16)
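
A mean absolute difference of 2.2344 is hard to interpret without knowing the scale of the layer's outputs. One way to put it in context, sketched below in plain Python, is to divide it by the mean magnitude of a reference output and compare the ratio with the rounding error that bfloat16 alone would introduce (at most about 2^-8 per element). The helper names `to_bf16`, `mean_abs_diff`, and `relative_error` are illustrative and are not part of BitBLAS or PyTorch; the rounding helper is a simplification that ignores NaN and infinity.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 encoding.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def mean_abs_diff(a, b):
    """Mean absolute element-wise difference (the metric used in the script)."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def relative_error(out, ref):
    """Mean absolute difference scaled by the mean magnitude of the reference."""
    scale = sum(abs(r) for r in ref) / len(ref)
    return mean_abs_diff(out, ref) / scale


# Made-up reference values and their bfloat16-rounded copies.
ref = [0.1 * (i + 1) for i in range(64)]
out = [to_bf16(r) for r in ref]
print("relative error from bf16 rounding alone:", relative_error(out, ref))
```

In the script from this issue, the same ratio could be computed as `(output_torch.float() - output_bitblas.float()).abs().mean() / output_torch.float().abs().mean()`. A ratio far above the rounding level would point to a real mismatch rather than bfloat16 precision loss.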

The linear layer is taken from a pretrained model that was trained in bf16.
CUDA version: 12.1
GPU: A10G
OS: Ubuntu
BitBLAS version: bitblas==0.0.1.dev15

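import time

import torch
from bitblas import Linear as BitBLASLinear

  # `name` and `linear_layer` are defined elsewhere (not shown in this snippet).
  print(f"qunatizing {name}")
  in_features = linear_layer.in_features
  out_features = linear_layer.out_features
  opt_M = 1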

  class Custom(BitBLASLinear):
      # Cast the BitBLAS output to bfloat16 so it can be compared with the torch output.
      def forward(self, A):
          out = super().forward(A)
          out = out.to(torch.bfloat16)
          return out
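  input_tensor = torch.rand(opt_M, in_features).to(torch.bfloat16).cuda()
  # Warm up the torch layer for about one second.
  st = time.time()
  while time.time() - st < 1.0:
      linear_layer(input_tensor)
  times = 1000
  with torch.no_grad():
      # There is no torch.cuda.synchronize() here, so this wall-clock average may
      # not reflect kernel execution time (CUDA kernels run asynchronously).
      start_time = time.time()
      for _ in range(times):
          output_torch = linear_layer(input_tensor)
      end_time = time.time()
  print(f"torch linear took by avg {(end_time-start_time)/times}")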
  start_time = time.time()
  # bitblas_linear = Int8Linear(linear_module=linear_torch)
  # BitBLASLinear.STORAGE_DTYPE='bfloat16'
  bitblas_linear = Custom(
      linear_layer.in_features,
      linear_layer.out_features,
      bias=linear_layer.bias is not None,
      opt_M=opt_M,
      accum_dtype='float32',
      A_dtype='bfloat16',
      W_dtype='bfloat16',
  )
  bitblas_linear.load_and_transform_weight(linear_layer.weight.clone())
  if linear_layer.bias is not None:
      bitblas_linear.bias.data = linear_layer.bias.data.clone()

  # Warm up for about one second. Note that the "init" time printed below
  # includes this warm-up loop.
  st = time.time()
  while time.time() - st < 1.0:
      bitblas_linear(input_tensor)
  end_time = time.time()
  print(f"bitblas linear took to init : {(end_time-start_time)} sec")
  bitblas_linear.cuda()
  with torch.no_grad():
      start_time = time.time()
      for _ in range(times):
          output_bitblas = bitblas_linear(input_tensor)
      end_time = time.time()
  print(f"bitblas linear took by avg {(end_time-start_time)/times}")

  # Mean absolute difference between the torch and BitBLAS outputs.
  diff = torch.mean(torch.abs(output_torch.to(torch.bfloat16) - output_bitblas.to(torch.bfloat16)))
  print("torch compare : ", diff)
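
Because CUDA kernels are launched asynchronously, wall-clock timing around a loop of GPU calls is usually more trustworthy when the device is synchronized before the timer starts and again before it stops. The following is a minimal sketch of such a timing helper, not a definitive benchmark harness; the name `benchmark` and its parameters (`sync`, `warmup_s`, `iters`) are hypothetical and are not part of BitBLAS or PyTorch.

```python
import time


def _no_sync():
    """Default synchronization hook: nothing to wait for on CPU-only code."""


def benchmark(fn, sync=None, warmup_s=1.0, iters=1000):
    """Return the average wall-clock seconds per call of `fn`.

    `sync` is an optional callable that blocks until queued device work has
    finished; it is called once before and once after the timed loop.
    """
    if sync is None:
        sync = _no_sync
    # Warm-up phase: run `fn` repeatedly for roughly `warmup_s` seconds.
    deadline = time.perf_counter() + warmup_s
    while time.perf_counter() < deadline:
        fn()
    sync()  # make sure warm-up work is done before starting the timer
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for all timed work to finish before stopping the timer
    return (time.perf_counter() - start) / iters


# CPU-only demonstration with a stand-in workload.
avg = benchmark(lambda: sum(range(100)), warmup_s=0.01, iters=100)
print(f"average seconds per call: {avg}")
```

For the layers in this issue, the call would look like `benchmark(lambda: bitblas_linear(input_tensor), sync=torch.cuda.synchronize)` inside a `torch.no_grad()` block, and likewise for `linear_layer`.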
