Skip to content

Add FP8 kernel acceleration for compressed-tensors quantized models#45699

Open
jiqing-feng wants to merge 55 commits into
huggingface:mainfrom
jiqing-feng:fp8
Open

Add FP8 kernel acceleration for compressed-tensors quantized models#45699
jiqing-feng wants to merge 55 commits into
huggingface:mainfrom
jiqing-feng:fp8

Conversation

@jiqing-feng

@jiqing-feng jiqing-feng commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

CI

What does this PR do?

This PR adds native FP8 matmul kernel support for compressed-tensors FP8 quantized models in transformers. Previously, compressed-tensors FP8 models were loaded via the compressed-tensors library and dequantized back to FP16/BF16 for inference. With this change, FP8 weights are kept in FP8 format and inference uses hardware-accelerated FP8 matmul kernels (torch._scaled_mm on XPU, fbgemm.f8f8bf16_rowwise on CUDA).

Key changes:

New file: src/transformers/integrations/compressed_tensors_fp8.py

  • CTFP8Linear: FP8 linear layer that stores weights in FP8 and uses row-wise FP8 matmul kernels. Activations are dynamically quantized per-row via quantize_fp8_per_row.
  • Weight converters (CompressedTensorsScaleConvert, CompressedTensorsFp8Dequantize) to handle the checkpoint format conversion (e.g. weight_scaleweight_scale_inv).
  • CTFP8PerRowQuantize: Online quantization support — quantize BF16 weights to FP8 per-row on-the-fly during model loading.

Modified: src/transformers/quantizers/quantizer_compressed_tensors.py

  • CompressedTensorsHfQuantizer now detects FP8 quantization configs (float type, num_bits=8) and automatically routes to the FP8 kernel path when GPU/XPU is available. Falls back to the default compressed-tensors dequantize path on CPU.
  • Added get_weight_conversions() and get_quantize_ops() to support both pre-quantized loading and online quantization.
  • No changes to the non-FP8 code path — existing INT8/INT4 compressed-tensors models are unaffected.

Modified: src/transformers/quantizers/auto.py

  • Minor formatting change (no functional change).

Supported models

Usage

Pre-quantized model (no config needed)

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Online quantization

from transformers import AutoModelForCausalLM, CompressedTensorsConfig
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs, QuantizationType, QuantizationStrategy

ct_config = CompressedTensorsConfig(
    config_groups={
        "group_0": QuantizationScheme(
            weights=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT, strategy=QuantizationStrategy.CHANNEL,
            ),
        ),
    },
    run_compressed=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    quantization_config=ct_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Devices

  • XPU (Intel Data Center Max / Arc): uses torch._scaled_mm
  • CUDA (SM89+): uses fbgemm.f8f8bf16_rowwise
  • CPU: falls back to default compressed-tensors dequantize path

@sywangyi

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng changed the title Fp8 Add FP8 kernel acceleration for compressed-tensors quantized models Apr 29, 2026
@Rocketknight1

Copy link
Copy Markdown
Member

cc @SunMarc

@SunMarc SunMarc left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, left a comment !

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng marked this pull request as ready for review April 30, 2026 03:11
@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @SunMarc . Please check it the integration is ok. I'll clean the tests and doc after you approved the integration.

@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @SunMarc . Would you please review the PR? Thanks!

@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @Rocketknight1 . It seems that @SunMarc does not have bandwidth to review this PR. Would you please help to review the PR? Thanks!

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/quantizers/auto.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for the integration part ! cc @SunMarc for the quantizer part

Comment thread src/transformers/quantizers/quantizer_compressed_tensors.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

@ArthurZucker ArthurZucker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Nit on walking / updating existing conversion

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/quantizers/quantizer_compressed_tensors.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @ArthurZucker . Would you please review my new commit to check if it fixed your comments? Thanks!

@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @SunMarc . Would you please let me know what need to be changed before merging? Thanks!

@SunMarc SunMarc left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your work !

I think it will still be better if we don't do online quantization no ? Like users can just use compressed tensors to do it no ? If they want to use online fp8, they have finegrained-fp8 for that and we can update the support if needed. The thing is that I don't want to introduce a new way to create CT checkpoints + maintenance overhead for reverse ops that needs to match CT implementation and ours. Right now, what you did is to dequantize back to BF16 if someone wants to save the online quantized model.

Also, for the FP8 kernels, we should probably add a new arg like the other quantization method called dequantize. run_compressed don't work anymore as you saw. CT just decompresses the model on the first forward.

x = input.reshape(-1, input.shape[-1])
x_quantized, x_scale = _quantize_fp8_per_row(x)

weight_scale_float32 = self.weight_scale_inv.to(torch.float32)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weight_scale will be in float32 if you specified it in the modeling so we don't need to do that

Suggested change
weight_scale_float32 = self.weight_scale_inv.to(torch.float32)

Comment on lines +123 to +124
scale_b = weight_scale_float32.t()
if scale_b.shape[-1] == 1 and self.out_features > 1:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe define a variable called is_per_tensor ?


if _can_use_fp8_kernel():
# XPU or CUDA SM89+: FP8 kernel path (quantize activation + scaled_mm)
x = input.reshape(-1, input.shape[-1])

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can reshape the input before as we do for both path ?


module_kwargs = {} if pre_quantized else {"dtype": None}
if isinstance(module, nn.Linear):
with torch.device("meta"):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need this normally as this method is already under a context manager that does this

Suggested change
with torch.device("meta"):

Comment on lines +28 to +38
def _is_fp8_config(quantization_config: CompressedTensorsConfig) -> bool:
"""Check if a CompressedTensorsConfig describes FP8 quantization."""
ct_qconfig = quantization_config.quantization_config
if ct_qconfig is None:
return False
for group in ct_qconfig.config_groups.values():
weights = group.weights
if weights is not None and weights.type == "float" and weights.num_bits == 8:
return True
return False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can move that to compressed_tensors config class maybe ?

Comment on lines +171 to +172
if self.is_fp8:
return False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

depends if we dequantize or not also

Comment on lines +181 to +182
if self.is_fp8:
return False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Comment on lines +218 to +227
class CompressedTensorsActivationScaleConvert(ConversionOps):
"""Rename compressed-tensors `input_scale` to `activation_scale`."""

def convert(self, input_dict, **kwargs):
scale = input_dict["input_scale"][0]
return {"activation_scale": scale.to(torch.float32)}

@property
def reverse_op(self):
return _IdentityOp()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can just keep the same name no ?

Comment on lines +183 to +193
class CompressedTensorsScaleConvert(ConversionOps):
"""Convert compressed-tensors `weight_scale` to `weight_scale_inv`.

In compressed-tensors, `weight_scale` is the dequantization multiplier:
bf16_weight = fp8_weight * weight_scale

In our CompressedTensorsFP8Linear, `weight_scale_inv` has the same semantics (it's
multiplied with the FP8 weight to get the dequantized value), so no inversion is needed.
The conversion also reshapes the scale: scalar → (1, 1), 1D (N,) → (N, 1).
"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, we can keep the same name no ?

Comment on lines +230 to +240
class CompressedTensorsFp8Dequantize(ConversionOps):
"""Dequantize compressed-tensors FP8 weights back to BF16.

Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight,
producing a BF16 tensor. Prepended to a converter chain for layers that
cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``):
it pairs each weight with its sibling scale *by index* and preserves the
per-expert list structure so the downstream merge / concat ops still see
one tensor per expert.
"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this ? Like it is not really useful to online quantize a model to finally save it in bf16 no ?

With compressed tensors, it will dequantize the model in any case no if you specify run_compressed=False no for a quantized model.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also can you explain the moe bit, i didn't fully understand

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @SunMarc, thanks for the review! Addressed everything; kept the MoE dequant-before-merge on purpose (explained below).

Done

  • No more online quant — removed CompressedTensorsFP8PerRowQuantize; get_quantize_ops returns None and param_needs_quantization is False for FP8. Online FP8 → use finegrained-fp8.
  • dequantize arg — added dequantize: bool = False to CompressedTensorsConfig (like the other methods), wired through to replace_with_compressed_tensors_fp8_linear. run_compressed no longer drives this.
  • (1) dropped the redundant .to(float32) (weight_scale is already f32).
  • (2) added an is_per_tensor variable.
  • (3) input is reshaped to 2D once, shared by both paths.
  • (4) removed with torch.device("meta").
  • (5) moved _is_fp8_configCompressedTensorsConfig.is_fp8.
  • (6) / (7) is_trainable / is_qat_trainable now return dequantize for FP8 (trainable only on the dequantized BF16 path).
  • (8) kept input_scale — the converter was dead code, removed it.
  • (9) kept weight_scale (no more weight_scale_inv); the converter only reshapes.

Side effect: CompressedTensorsFp8Dequantize.reverse_op is now _IdentityOp() (we never re-quantize on save).

The MoE dequant (comments 10 & 11)

MoE checkpoints store per-expert weights+scales, but transformers merges the experts (stack/cat) into a single 3D packed nn.Parameter — not an nn.Linear, so it can't hold a weight_scale, and the per-expert scales differ so they can't survive the merge. The only correct option is to dequantize each expert (weight * weight_scale) to BF16 before the merge: update_weight_conversions prepends CompressedTensorsFp8Dequantize to the merging converters, pairs weight+scale by expert index, drops the scales. Without it, MoE FP8 loading crashes. (Attention/router linears in the same model still stay FP8.) The dequantize=True flag reuses this same op to also fold plain linears to BF16 — that's the answer to comment 10.

Known limitation: since merged experts land in BF16, this PR doesn't save memory on MoE expert weights (the bulk of params) — only dense models and the attention/router part of MoE benefit. Not fundamental: experts could stay FP8 with (1) a 3D FP8 expert param, (2) a stacked per-expert scale tensor, (3) a grouped scaled-mm in the MoE forward. That's a bigger change, so I'd suggest a follow-up and keeping this dequant path as the correct baseline.

Happy to adjust naming or split differently — let me know!

@SunMarc SunMarc left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a couple of comments ! Try to fix those but I can take over the PR at some point if needed. I really want this is be done correctly

Comment on lines 1116 to 1117
run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
emulate compressed model execution if True, otherwise use default submodule

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run_compressed should be deprecated and we should use dequantize instead
dequantize = not run_compressed
we shouldn't use it anywhere else

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. run_compressed is deprecated (warns and maps to dequantize); dequantize is now the single source of truth for all internal logic.

Comment on lines +164 to +169

def get_quantize_ops(self):
# Online FP8 quantization is not supported (use finegrained-fp8 instead),
# so there is never an on-the-fly quantization op to apply.
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need to define it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines +211 to +233
from ..integrations.compressed_tensors_fp8 import CompressedTensorsFp8Dequantize

updated: list = []
for conv in weight_conversions:
# Only WeightConverter has ``.operations`` to extend; WeightRenaming
# rules just pass through untouched.
if not isinstance(conv, WeightConverter):
updated.append(conv)
continue

weight_sources = [p for p in conv.source_patterns if p.endswith(".weight")]
if weight_sources:
anchored_weight = [p + "$" for p in weight_sources]
scale_sources = [p[: -len(".weight")] + ".weight_scale$" for p in weight_sources]
other = [p for p in conv.source_patterns if not p.endswith(".weight")]
new_sources = anchored_weight + scale_sources + other
new_ops = [CompressedTensorsFp8Dequantize()] + list(conv.operations)
conv = WeightConverter(
source_patterns=new_sources,
target_patterns=conv._original_target_patterns,
operations=new_ops,
)
updated.append(conv)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that when we do fusion, we will have to dequantize this particular module ? If this is the case, maybe we should return a warning no ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, added a logger.warning_once when merged (e.g. MoE expert) weights fall back to dequantized BF16.

def is_trainable(self):
# The FP8 kernel path is inference-only. With `dequantize=True` we don't enter
# it (`is_fp8` is False) and fall back to the regular compressed-tensors route.
if self.is_fp8:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we think of a better name compared to is_fp8 ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. The quantizer flag is now use_fp8_kernel (is_fp8 on the config still means "the checkpoint is FP8 type").

Comment on lines +61 to +62
if quantization_config.is_fp8 and not quantization_config.dequantize:
self.is_fp8 = True

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it might be better to put the quantization_config.dequantize in the is_fp8 method directly no ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the dequantize check off the quantizer and into the config, as you asked. I kept is_fp8 meaning purely "the checkpoint is FP8 type" (so it doesn't flip to False for an FP8 checkpoint that's being dequantized) and added a dedicated property CompressedTensorsConfig.use_fp8_kernel = is_fp8 and not dequantize. The quantizer now just reads use_fp8_kernel. This also covers your "better name than is_fp8" comment. Happy to fold it straight into is_fp8 instead if you prefer.

Comment on lines +229 to +240
def _scale_pattern_for(weight_pattern: str) -> str:
# Strip the optional ``$`` regex anchor so we can match the underlying name.
anchored = weight_pattern.endswith("$")
base = weight_pattern[:-1] if anchored else weight_pattern
if base.endswith(".weight"):
scale = base[: -len(".weight")] + ".weight_scale"
elif base == "weight":
scale = "weight_scale"
else:
scale = base + "_scale"
return scale + "$" if anchored else scale

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need all these cases ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, reduced to the single .weight -> .weight_scale case.

Comment on lines +217 to +226
class CompressedTensorsFp8Dequantize(ConversionOps):
"""Dequantize compressed-tensors FP8 weights back to BF16.

Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight,
producing a BF16 tensor. Prepended to a converter chain for layers that
cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``):
it pairs each weight with its sibling scale *by index* and preserves the
per-expert list structure so the downstream merge / concat ops still see
one tensor per expert.
"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general, can you try to simplify this a bit

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines +92 to +93
if not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())):
self.skipTest("FP8 kernel path requires GPU or XPU")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't do that, use the flags

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, replaced the manual skipTest with @require_torch_accelerator.

# check perplexity
perplexity = torch.exp(outputs.loss)
self.assertLessEqual(perplexity, expected_perplexity)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add more integration tests. test the dequantize path. with your changes, maybe some fp8 tests don't behave the same also. Add a test also when saving the fp8 model, it should work again when reloading maybe. feels like the changes we did are quite minor.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Added test_tinyllama_fp8_dequantize (CPU dequantize path) and test_tinyllama_fp8_save_reload.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@github-actions

github-actions Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: compressed_tensors_integration

@github-actions

github-actions Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

CI recap

Dashboard: View test results in Grafana
Latest run: 28495931842:2
Result: success | Jobs: 14 | Tests: 71,713 | Failures: 0 | Duration: 17h 31m

@jiqing-feng

Copy link
Copy Markdown
Contributor Author

Hi @SunMarc . I've fixed all your comments. Please check it. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants