276 changes (169 additions, 107 deletions) in torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py
@@ -14,69 +14,116 @@ def _quantize_array(
n_bits = 8
int_max = 2**(n_bits - 1) - 1
scale = (x_abs_max_val / int_max).T # [bs_block_size, 1]
# Need to explicitly cast to f32 because Mosaic can't directly jnp.round a
# bf16 array.
# It seems x/0 in Pallas generates inf/-inf instead of an exception.
x_int = jnp.round((x / scale).astype(jnp.float32)).astype(jnp.int8)
return x_int, scale.astype(x.dtype)
x_int = jnp.round(x / scale).astype(jnp.int8)
return x_int, scale.astype(jnp.float32)
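For reference, the scheme above is plain symmetric per-row int8 quantization. A standalone JAX sketch outside Pallas (illustrative only; quantize_rows and its inputs are not part of this PR, and the row-wise abs-max is computed inline here rather than arriving precomputed as the (1, batch_block_size)-shaped x_abs_max_ref that the kernel transposes):

import jax.numpy as jnp

def quantize_rows(x):  # x: [bs, n_in]
  int_max = 2**(8 - 1) - 1  # 127 for int8
  abs_max = jnp.max(jnp.abs(x), axis=1, keepdims=True)  # [bs, 1]
  scale = abs_max / int_max  # one scale per row
  x_int = jnp.round(x / scale).astype(jnp.int8)
  return x_int, scale.astype(jnp.float32)

x = jnp.array([[0.5, -1.0], [2.0, 0.25]], dtype=jnp.bfloat16)
x_int, x_scale = quantize_rows(x)
# Dequantized values match x up to quantization error (about scale / 2).
x_approx = x_int.astype(jnp.float32) * x_scale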


def unfold_args(args: tuple[jax.Array | bool, ...], fn_args: tuple[bool, ...],
fn):
if len(args) == 0:
fn(*fn_args)
else:
arg = args[0]
if isinstance(arg, bool):
unfold_args(args[1:], fn_args + (arg,), fn)
else:
assert arg.dtype == jnp.bool and arg.size == 1
lax.cond(
arg,
lambda: unfold_args(args[1:], fn_args + (True,), fn),
lambda: unfold_args(args[1:], fn_args + (False,), fn),
)
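unfold_args converts a mix of traced boolean scalars and static Python bools into all-static bools: each traced flag becomes a lax.cond whose two branches recurse with True and False, while Python bools pass straight through. The net effect is that fn always receives plain bools and can specialize with ordinary if statements, at the cost of tracing fn once per reachable combination. A hypothetical usage sketch (body and demo are not from the PR):

import jax
import jax.numpy as jnp

def body(use_fast_path, is_last):
  # Both flags arrive as plain Python bools, so ordinary `if` statements
  # work; each combination is traced as its own lax.cond branch.
  print("tracing body with", use_fast_path, is_last)

def demo(pred):
  # `pred` is a traced scalar bool; the literal True is already static.
  unfold_args((pred, True), (), body)

# Traces body(True, True) and body(False, True); at run time the cond
# dispatches on the value of `pred`.
jax.jit(demo)(jnp.array(True))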


def matmul_kernel(
x_ref, # (batch_block_size, in_block_size)
w_ref, # (out_block_size, in_block_size)
scalar_ref, # (1, out_block_size)
x_abs_max_val, # (1, batch_block_size)
out_ref, # (batch_block_size, out_block_size)
acc_ref, # (batch_block_size, out_block_size)
x_ref: jax.Array, # (batch_block_size, in_block_size)
w_ref: jax.Array, # (out_block_size, in_block_size)
scalar_ref: jax.Array, # (1, out_block_size)
x_abs_max_ref: jax.Array, # (1, batch_block_size)
out_ref: jax.Array, # (batch_block_size, out_block_size)
acc_scratch: jax.Array, # (batch_block_size, out_block_size)
q_x_scratch: jax.Array, # (batch_block_size, in_block_size)
x_scale_scratch: jax.Array, # (batch_block_size, 1)
*,
quantize_activation,
batch_block_size,
out_block_size,
in_block_size,
quantize_activation: bool,
save_acc: bool,
save_q_x: bool,
batch_block_size: int,
out_block_size: int,
in_block_size: int,
):
bs_idx, out_idx, in_idx = pl.program_id(0), pl.program_id(1), pl.program_id(2)
nsteps = pl.num_programs(2)
n_in = pl.num_programs(2)
x_ref_dtype = x_ref.dtype
assert x_ref.shape == (batch_block_size,
in_block_size), "x_ref shape is not correct"
assert w_ref.shape == (out_block_size,
in_block_size), "w_ref shape is not correct"
assert scalar_ref.shape == (1,
out_block_size), "scalar_ref shape is not correct"
assert x_abs_max_val.shape == (
assert x_abs_max_ref.shape == (
1, batch_block_size), "x_max_val shape is not correct"
assert out_ref.shape == (batch_block_size,
out_block_size), "out_ref shape is not correct"
assert acc_ref.shape == (batch_block_size,
out_block_size), "acc_ref shape is not correct"

@pl.when(in_idx == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)

if quantize_activation:
x, x_scale = _quantize_array(x_ref[...], x_abs_max_val[...])
acc_ref[...] += jax.lax.dot_general(
x,
w_ref[...],
(((1,), (1,)), ((), ())),
preferred_element_type=jnp.int32,
)

if save_q_x:
assert quantize_activation
assert q_x_scratch is not None
assert x_scale_scratch is not None
quant = out_idx == 0
else:
acc_ref[...] += jax.lax.dot_general(
x_ref[...],
w_ref[...],
(((1,), (1,)), ((), ())),
)

@pl.when(in_idx == nsteps - 1)
def _():
acc = acc_ref[...]
scalar = scalar_ref[...]
acc *= scalar
assert q_x_scratch is None
assert x_scale_scratch is None
quant = quantize_activation

if save_acc:
assert acc_scratch is not None
is_first_step = in_idx == 0
is_last_step = in_idx == n_in - 1
else:
assert acc_scratch is None
is_first_step = True
is_last_step = True
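# Each flag above is either a Python bool (static) or a traced scalar
# (dynamic). unfold_args below lowers the traced ones to lax.cond branches
# and hands matmul_body plain bools, so the body is specialized once per
# reachable combination; fully static flags add no control flow at all.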

def matmul_body(quant, is_first_step, is_last_step):
if quantize_activation:
acc *= x_scale
out_ref[...] = acc.astype(x_ref_dtype)
if quant:
q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
if save_q_x:
q_x_scratch[...] = q_x_tmp
x_scale_scratch[...] = x_scale_tmp
else:
assert save_q_x
q_x_tmp = q_x_scratch[...]
if is_last_step:
x_scale_tmp = x_scale_scratch[...]

acc = jax.lax.dot_general(
q_x_tmp,
w_ref[...],
(((1,), (1,)), ((), ())),
preferred_element_type=jnp.int32,
)
else:
acc = jax.lax.dot_general(
x_ref[...],
w_ref[...],
(((1,), (1,)), ((), ())),
)

if not is_first_step:
acc += acc_scratch[...]

if is_last_step:
acc *= scalar_ref[...]
if quantize_activation:
acc *= x_scale_tmp
out_ref[...] = acc.astype(x_ref_dtype)
else:
assert save_acc
acc_scratch[...] = acc

unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
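Seen end to end, matmul_kernel is a blocked matmul: grid dimension 2 walks the contraction (in) dimension, partial products are carried in acc_scratch between steps, and the per-output-channel scalar (plus, when quantizing, the per-row x_scale) is applied once on the last step. A Pallas-free sketch of the same accumulation pattern, with illustrative shapes:

import jax.numpy as jnp

def blocked_matmul(x, w, in_block_size):
  # x: [bs, n_in], w: [n_out, n_in]; computes x @ w.T one in-block at a time.
  acc = jnp.zeros((x.shape[0], w.shape[0]), dtype=jnp.float32)
  for in_idx in range(x.shape[1] // in_block_size):  # plays pl.program_id(2)
    lo = in_idx * in_block_size
    acc += x[:, lo:lo + in_block_size] @ w[:, lo:lo + in_block_size].T
  return acc  # the real kernel multiplies by scalar (and x_scale) here

x = jnp.ones((8, 64), jnp.float32)
w = jnp.ones((16, 64), jnp.float32)
assert jnp.allclose(blocked_matmul(x, w, 16), x @ w.T)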


def _next_multiple(x, multiple):
@@ -159,10 +206,22 @@ def quantized_matmul_int8(
# Within the kernel, some extra VMEM is used for intermediates and vreg spills.
vmem_used = vmem_to_be_transferred * 2
vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
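# Worked example (illustrative numbers): if the blocks to be transferred
# total 8 MiB, vmem_used is 16 MiB and the limit requested from the
# compiler is min(32 MiB, 96 MiB) = 32 MiB.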

n_bs = padded_bs // batch_block_size
n_out = padded_out_features // out_block_size
n_in = padded_in_features // in_block_size

save_acc = n_in > 1
# Cache the quantized input so it is not recomputed for every output block.
# Only worthwhile when a single input block covers the whole contraction
# (n_in == 1) and there are multiple output blocks to reuse it across.
save_q_x = quantize_activation and n_in == 1 and n_out > 1
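# Worked example (illustrative, not from this PR): padded_bs = 1024,
# padded_out_features = 4096, padded_in_features = 4096 with block sizes
# (256, 1024, 4096) gives a grid of (n_bs, n_out, n_in) = (4, 4, 1).
# n_in == 1 means every invocation sees the full contraction, so no
# accumulator carry is needed (save_acc = False); n_out == 4 means the same
# x block is revisited four times, so with quantize_activation the quantized
# x and its scale are computed once at out_idx == 0 and reused (save_q_x).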

kernel = pl.pallas_call(
functools.partial(
matmul_kernel,
quantize_activation=quantize_activation,
save_acc=save_acc,
save_q_x=save_q_x,
batch_block_size=batch_block_size,
out_block_size=out_block_size,
in_block_size=in_block_size),
@@ -181,15 +240,18 @@ def quantized_matmul_int8(
out_specs=pl.BlockSpec((batch_block_size, out_block_size),
lambda b, o, i: (b, o)),
scratch_shapes=[
pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
pltpu.VMEM((batch_block_size,
out_block_size), acc_dtype) if save_acc else None,
pltpu.VMEM((batch_block_size,
in_block_size), jnp.int8) if save_q_x else None,
pltpu.VMEM(
(batch_block_size, 1), jnp.float32) if save_q_x else None,
],
grid=(padded_bs // batch_block_size,
padded_out_features // out_block_size,
padded_in_features // in_block_size),
grid=(n_bs, n_out, n_in),
),
out_shape=jax.ShapeDtypeStruct((padded_bs, padded_out_features), x.dtype),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary"),
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
vmem_limit_bytes=vmem_limit_bytes,
),
)
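The dimension_semantics change is tied to the new caching: the out grid dimension drops from "parallel" to "arbitrary" because q_x_scratch and x_scale_scratch, written at out_idx == 0, must still hold their values when later out blocks read them, and that carry is only sound if the dimension executes in order. A minimal toy TPU kernel (not from this PR) showing the same scratch-carry pattern that forces "arbitrary" semantics:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def running_sum_kernel(x_ref, o_ref, acc_ref):
  # acc_ref carries state across grid steps, so the grid dimension must
  # execute in order ("arbitrary"), not "parallel".
  @pl.when(pl.program_id(0) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)
  acc_ref[...] += x_ref[...]
  o_ref[...] = acc_ref[...]

x = jnp.ones((8, 128), jnp.float32)
running_sums = pl.pallas_call(
    running_sum_kernel,
    grid=(8,),
    in_specs=[pl.BlockSpec((1, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((1, 128), lambda i: (i, 0)),
    scratch_shapes=[pltpu.VMEM((1, 128), jnp.float32)],
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    compiler_params=pltpu.TPUCompilerParams(
        dimension_semantics=("arbitrary",),),
)(x)  # row i of running_sums is the sum of rows 0..i of x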
@@ -217,70 +279,70 @@ def quantized_matmul_int8(
# - out_block_size
# - in_block_size
TUNED_BLOCK_SIZES = {
(6, 1024, 1280, 8192, 'bfloat16', True): (1024, 1280, 2048),
(6, 1024, 28672, 4096, 'bfloat16', True): (1024, 3584, 4096),
(6, 1024, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
(6, 1024, 4096, 4096, 'bfloat16', True): (1024, 1024, 4096),
(6, 1024, 6144, 4096, 'bfloat16', True): (1024, 1536, 4096),
(6, 1024, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
(6, 1024, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
(6, 1024, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
(6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
(6, 128, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
(6, 128, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
(6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
(6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
(6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096),
(6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
(6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
(6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 128, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
(6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
(6, 128, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
(6, 16, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
(6, 16, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
(6, 16, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
(6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
(6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 16, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
(6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
(6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
(6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096),
(6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
(6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
(6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
(6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
(6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
(6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
(6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
(6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
(6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096),
(6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
(6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096),
(6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336),
(6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
(6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
(6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
(6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
(6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
(6, 2048, 1280, 8192, 'bfloat16', True): (512, 1280, 8192),
(6, 2048, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
(6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
(6, 2048, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
(6, 2048, 6144, 4096, 'bfloat16', True): (1024, 3072, 4096),
(6, 2048, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
(6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
(6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
(6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
(6, 128, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 512),
(6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192),
(6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
(6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
(6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
(6, 16, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192),
(6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
(6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192),
(6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
(6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192),
(6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024),
(6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192),
(6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584),
(6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192),
(6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
(6, 2048, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
(6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 2048),
(6, 256, 28672, 4096, 'bfloat16', True): (256, 1792, 4096),
(6, 256, 4096, 14336, 'bfloat16', True): (256, 1024, 3584),
(6, 256, 4096, 4096, 'bfloat16', True): (256, 1024, 4096),
(6, 256, 6144, 4096, 'bfloat16', True): (256, 1024, 4096),
(6, 256, 7168, 8192, 'bfloat16', True): (256, 1024, 4096),
(6, 256, 8192, 1024, 'bfloat16', True): (256, 4096, 1024),
(6, 256, 8192, 3584, 'bfloat16', True): (256, 1024, 3584),
(6, 32, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
(6, 32, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
(6, 32, 4096, 14336, 'bfloat16', True): (128, 1024, 3584),
(6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
(6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 32, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
(6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
(6, 2048, 7168, 8192, 'bfloat16', True): (2048, 256, 8192),
(6, 2048, 8192, 3584, 'bfloat16', True): (2048, 512, 3584),
(6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
(6, 32, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
(6, 512, 1280, 8192, 'bfloat16', True): (512, 1280, 2048),
(6, 512, 28672, 4096, 'bfloat16', True): (512, 3584, 4096),
(6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 1792),
(6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
(6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
(6, 512, 7168, 8192, 'bfloat16', True): (512, 1024, 8192),
(6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
(6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
(6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
(6, 64, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
(6, 64, 4096, 14336, 'bfloat16', True): (128, 512, 7168),
(6, 64, 4096, 4096, 'bfloat16', True): (128, 1024, 4096),
(6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
(6, 64, 7168, 8192, 'bfloat16', True): (128, 896, 4096),
(6, 64, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
(6, 64, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
(6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
}
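Each key appears to be (tpu_version, bs, n_output_features, n_input_features, dtype, quantize_activation), with the leading 6 denoting the TPU generation, and each value is the tuned (batch_block_size, out_block_size, in_block_size). A hypothetical accessor (the real file defines its own lookup; this only illustrates how the table is consulted, and the fallback triple is made up):

def get_tuned_block_sizes_sketch(tpu_version, bs, n_out, n_in, dtype,
                                 quantize_activation,
                                 default=(256, 512, 4096)):
  key = (tpu_version, bs, n_out, n_in, dtype, quantize_activation)
  return TUNED_BLOCK_SIZES.get(key, default)

# E.g. a bf16 [512, 4096] x [6144, 4096]^T matmul with activation
# quantization on TPU v6:
batch_blk, out_blk, in_blk = get_tuned_block_sizes_sketch(
    6, 512, 6144, 4096, 'bfloat16', True)  # -> (512, 1024, 4096)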

