Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 1cc769 to 321f41
15 changes: 15 additions & 0 deletions bitblas/base/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,18 @@ def get_arch(target: tvm.target.Target) -> TileDevice:
return CDNA(target)
else:
raise ValueError(f"Unsupported target: {target.kind.name}")


def is_ampere_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(isinstance(arch, CUDA))
conditions.append(arch.sm_version >= 80)
return all(conditions)


def is_volta_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(isinstance(arch, CUDA))
conditions.append(arch.sm_version >= 70)
conditions.append(arch.sm_version < 80)
return all(conditions)
10 changes: 8 additions & 2 deletions bitblas/ops/base_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Union, Callable
from typing import Union, Callable, List
from dataclasses import dataclass, field
from tvm.tl.transform import Simplify
from abc import ABC, abstractmethod
from bitblas.base.arch import TileDevice
from bitblas.base.roller.hint import Hint
from bitblas.tl.base_hint import BaseTLHint


# Decorator to simplify the output of a function
Expand Down Expand Up @@ -54,7 +56,7 @@ def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]):
return stmt

@abstractmethod
def with_default_config(self):
def with_default_config(self) -> PrimFunc:
pass

@abstractmethod
Expand All @@ -65,6 +67,10 @@ def apply_config(
):
pass

def serialze_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]:
# Convert Roller Hints to TileLang Hints
raise NotImplementedError

@property
def common_header(self):
# TODO(lei): For HIP Backend it should be different
Expand Down
51 changes: 24 additions & 27 deletions bitblas/ops/general_matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from tvm.target import Target
import operator
from functools import reduce
from bitblas.base.arch.cuda import CUDA
from bitblas.base.arch.cdna import CDNA
from bitblas.base.roller.hint import Hint
from typing import Any, Literal, Optional, Tuple, Union
from ..operator import OperatorConfig, Operator, OPExecutorCPU, BaseKernelNameGenerator
Expand Down Expand Up @@ -290,24 +288,25 @@ def generate(self, hint=None) -> str:

precision_str = (f"{A_dtype}x{W_dtype}")
kernel_name = "_".join([kernel_name, shape_str, precision_str])

# if config.with_scaling:
# kernel_name += "Scale"

# if config.with_zeros:
# if config.zeros_mode == "original":
# kernel_name += "OriginalZeros"
# elif config.zeros_mode == "rescale":
# precision_str += "RescaleZeros"
# elif config.zeros_mode == "quantized":
# precision_str += "QuantizedZeros"
# else:
# raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}")

# if config.propagate_a is not TransformKind.NonTransform:
# kernel_name += f"_pa{config.propagate_a.value}"
# if config.propagate_b is not TransformKind.NonTransform:
# kernel_name += f"_pb{config.propagate_b.value}"
'''
if config.with_scaling:
kernel_name += "Scale"

if config.with_zeros:
if config.zeros_mode == "original":
kernel_name += "OriginalZeros"
elif config.zeros_mode == "rescale":
precision_str += "RescaleZeros"
elif config.zeros_mode == "quantized":
precision_str += "QuantizedZeros"
else:
raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}")

if config.propagate_a is not TransformKind.NonTransform:
kernel_name += f"_pa{config.propagate_a.value}"
if config.propagate_b is not TransformKind.NonTransform:
kernel_name += f"_pb{config.propagate_b.value}"
'''

kernel_name = "_".join([kernel_name, self.serialize_hint(hint)])
assert self.is_valid(kernel_name), "Kernel name invalid"
Expand Down Expand Up @@ -390,16 +389,12 @@ def dispatch_tir(self,
from_database: bool = False,
source_format: str = "uint",
enable_tuning: bool = True):
'''Dispatch the tir script implementation'''
if (target.kind.name == "cuda"):
self.arch = CUDA(target)
elif (target.kind.name == "hip"):
self.arch = CDNA(target)

if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
self.ir_module["main"] = self.ir_module["main"].with_attrs(
{"opt_shapes": self.dynamic_range})
if self.is_tir_backend():
self.ir_module["main"] = self.ir_module["main"].with_attrs(
{"opt_shapes": self.dynamic_range})
else:
self.dynamic_range = None

Expand Down Expand Up @@ -600,6 +595,7 @@ def _select_implementation(self):
def _select_scheduler(self):
if is_native_compute(self.A_dtype, self.W_dtype):
return consistent_scheduler(
arch=self.arch,
M=self.M,
N=self.N,
K=self.K,
Expand All @@ -613,6 +609,7 @@ def _select_scheduler(self):
)
else:
return weight_dequantize_scheduler(
arch=self.arch,
M=self.M,
N=self.N,
K=self.K,
Expand Down
104 changes: 99 additions & 5 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

from .matmul_tensorcore import (
MatmulScheduler, # noqa: F401
MatmulBlockScheduler, # noqa: F401
MatmulFineGrainScheduler, # noqa: F401
MatmulWeightPropagationScheduler, # noqa: F401
)
Expand All @@ -22,6 +22,11 @@
MatmulINT4WeightPropagationScheduler, # noqa: F401
)

from bitblas.base.roller import TileDevice
from bitblas.base.arch import (
is_ampere_arch,
is_volta_arch,
)
from bitblas.ops.common import TransformKind
from typing import Union

Expand All @@ -40,7 +45,52 @@ def is_non_transform_kind(kind) -> bool:
return kind == TransformKind.NonTransform


def select_scheduler(
def volta_select_schduler(
M=None,
N=16384,
K=16384,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_bias=False,
layout="nt",
propagate_a: Union[int, TransformKind] = TransformKind.NonTransform,
propagate_b: Union[int, TransformKind] = TransformKind.NonTransform,
):
trans_A, trans_B = parse_layout(layout)
if isinstance(propagate_a, int):
propagate_a = TransformKind(propagate_a)
if isinstance(propagate_b, int):
propagate_b = TransformKind(propagate_b)

def check_if_not_supported():
conditions = [True]
conditions.append(propagate_a == TransformKind.NonTransform)
conditions.append(propagate_b == TransformKind.NonTransform)
conditions.append(trans_A is False)
conditions.append(trans_B is True)
conditions.append(in_dtype in ["int8", "float16", "float32"])
conditions.append(accum_dtype in ["int32", "float32"])
return all(conditions)

if not check_if_not_supported():
raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}")

Scheduler = MatmulFineGrainSIMTScheduler
return Scheduler(
M=M,
N=N,
K=K,
trans_A=trans_A,
trans_B=trans_B,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
)


def ampere_select_scheduler(
M=None,
N=16384,
K=16384,
Expand All @@ -60,8 +110,6 @@ def select_scheduler(
propagate_a = TransformKind(propagate_a)
if isinstance(propagate_b, int):
propagate_b = TransformKind(propagate_b)
if with_bias:
raise NotImplementedError

trans_A, trans_B = parse_layout(layout)

Expand Down Expand Up @@ -102,6 +150,7 @@ def is_int4_dtype(dtype):
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
)
if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
Scheduler = MatmulFineGrainScheduler if not is_int4_dtype(
Expand All @@ -115,9 +164,10 @@ def is_int4_dtype(dtype):
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
)
elif can_apply_block_scheduler(propagate_a, propagate_b):
return MatmulScheduler(
return MatmulBlockScheduler(
M=M,
N=N,
K=K,
Expand All @@ -126,6 +176,50 @@ def is_int4_dtype(dtype):
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
)
else:
raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}")


def select_scheduler(
arch: TileDevice,
M=None,
N=16384,
K=16384,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_bias=False,
layout="nt",
propagate_a: Union[int, TransformKind] = TransformKind.NonTransform,
propagate_b: Union[int, TransformKind] = TransformKind.NonTransform,
):
if is_ampere_arch(arch):
return ampere_select_scheduler(
M=M,
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
layout=layout,
propagate_a=propagate_a,
propagate_b=propagate_b,
)
elif is_volta_arch(arch):
return volta_select_schduler(
M=M,
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
with_bias=with_bias,
layout=layout,
propagate_a=propagate_a,
propagate_b=propagate_b,
)
else:
raise ValueError(f"Unsupported arch: {arch.name}")
Loading