Add FP8 kernel acceleration for compressed-tensors quantized models #45699
jiqing-feng wants to merge 55 commits into
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
cc @SunMarc
Hi @SunMarc. Please check whether the integration is OK. I'll clean up the tests and docs once you approve the integration.
Hi @SunMarc. Would you please review the PR? Thanks!
Hi @Rocketknight1. It seems that @SunMarc does not have the bandwidth to review this PR. Would you please help review it? Thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM for the integration part! cc @SunMarc for the quantizer part.
ArthurZucker left a comment
Nice! Nit on walking / updating existing conversion
Hi @ArthurZucker. Would you please review my new commit and check whether it addresses your comments? Thanks!
Hi @SunMarc. Would you please let me know what needs to be changed before merging? Thanks!
SunMarc left a comment
Thanks for your work!
I think it would still be better not to do online quantization, no? Users can just use compressed-tensors for that. If they want online FP8, they have finegrained-fp8, and we can update its support if needed. The thing is that I don't want to introduce a new way to create CT checkpoints, plus the maintenance overhead of reverse ops that need to match both the CT implementation and ours. Right now, what you do is dequantize back to BF16 if someone wants to save the online-quantized model.
Also, for the FP8 kernels, we should probably add a new argument called dequantize, like the other quantization methods have. run_compressed doesn't work anymore, as you saw: CT just decompresses the model on the first forward pass.
    x = input.reshape(-1, input.shape[-1])
    x_quantized, x_scale = _quantize_fp8_per_row(x)

    weight_scale_float32 = self.weight_scale_inv.to(torch.float32)

The weight_scale will already be in float32 if you specified that in the modeling code, so we don't need this cast.
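To make the `_quantize_fp8_per_row` step in the snippet above concrete, here is a minimal, torch-free sketch of per-row scaling. It is only an illustration under stated assumptions: `quantize_per_row` and `dequantize_per_row` are hypothetical names, 448 is the largest finite value of the FP8 E4M3 format, and the final rounding to an actual FP8 dtype is deliberately left out.

```python
# Largest finite value of the FP8 E4M3 format (torch.float8_e4m3fn).
FP8_E4M3_MAX = 448.0


def quantize_per_row(rows):
    """Scale each row into the FP8 E4M3 range; return (scaled_rows, scales).

    Hypothetical stand-in for `_quantize_fp8_per_row`: it computes only the
    per-row scale and the scaled values, and skips the rounding to FP8.
    """
    scaled_rows, scales = [], []
    for row in rows:
        amax = max(abs(v) for v in row)
        # Guard against all-zero rows so we never divide by zero.
        scale = max(amax, 1e-12) / FP8_E4M3_MAX
        scaled_rows.append(
            [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in row]
        )
        scales.append(scale)
    return scaled_rows, scales


def dequantize_per_row(scaled_rows, scales):
    """Invert the scaling: value = scaled_value * row_scale."""
    return [[v * s for v in row] for row, s in zip(scaled_rows, scales)]


x = [[0.5, -2.0, 1.0], [100.0, 25.0, -50.0]]
x_q, x_scale = quantize_per_row(x)      # one scale per row of x
x_back = dequantize_per_row(x_q, x_scale)
```

Each row gets its own scale, so a row with small values is not crushed by a large value elsewhere in the batch; the scales are kept in full precision, which is why the reviewer expects them to already be float32.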
    weight_scale_float32 = self.weight_scale_inv.to(torch.float32)
    scale_b = weight_scale_float32.t()
    if scale_b.shape[-1] == 1 and self.out_features > 1:

Maybe define a variable called is_per_tensor?
    if _can_use_fp8_kernel():
        # XPU or CUDA SM89+: FP8 kernel path (quantize activation + scaled_mm)
        x = input.reshape(-1, input.shape[-1])

Maybe we can reshape the input beforehand, since we do it for both paths?
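The comment in the snippet above says the kernel path is taken on "XPU or CUDA SM89+". A minimal sketch of such a gate, assuming SM89 means CUDA compute capability 8.9; `can_use_fp8_kernel` and its parameters are hypothetical and torch-free, not the PR's actual `_can_use_fp8_kernel`:

```python
def can_use_fp8_kernel(device_type, cuda_capability=None):
    """Hypothetical, torch-free version of an "XPU or CUDA SM89+" check.

    device_type:     "xpu", "cuda" or "cpu".
    cuda_capability: (major, minor) tuple, only meaningful for "cuda".
    """
    if device_type == "xpu":
        return True
    if device_type == "cuda" and cuda_capability is not None:
        # Tuples compare lexicographically, so (9, 0) >= (8, 9) is True
        # while (8, 6) >= (8, 9) is False.
        return tuple(cuda_capability) >= (8, 9)
    return False
```

Keeping the check in one small function means both the forward pass and the quantizer can ask the same question and fall back to the dequantize path when it returns False.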
    module_kwargs = {} if pre_quantized else {"dtype": None}
    if isinstance(module, nn.Linear):
        with torch.device("meta"):

We don't normally need this, as this method already runs under a context manager that does it.
    def _is_fp8_config(quantization_config: CompressedTensorsConfig) -> bool:
        """Check if a CompressedTensorsConfig describes FP8 quantization."""
        ct_qconfig = quantization_config.quantization_config
        if ct_qconfig is None:
            return False
        for group in ct_qconfig.config_groups.values():
            weights = group.weights
            if weights is not None and weights.type == "float" and weights.num_bits == 8:
                return True
        return False

We can move that to the compressed_tensors config class, maybe?
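One way the suggestion above could look is a property on the config class itself. This is a sketch only: `WeightArgs`, `ConfigGroup` and `SketchConfig` are hypothetical stand-ins for the compressed-tensors and transformers classes, and the nested `quantization_config` attribute from the snippet is flattened to a single `config_groups` field for brevity.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class WeightArgs:  # hypothetical stand-in for the per-group weight arguments
    type: str
    num_bits: int


@dataclass
class ConfigGroup:  # hypothetical stand-in for one entry of `config_groups`
    weights: Optional[WeightArgs] = None


@dataclass
class SketchConfig:  # hypothetical stand-in for CompressedTensorsConfig
    config_groups: Optional[Dict[str, ConfigGroup]] = None

    @property
    def is_fp8(self) -> bool:
        """True if any config group quantizes weights to 8-bit float."""
        if self.config_groups is None:
            return False
        return any(
            g.weights is not None
            and g.weights.type == "float"
            and g.weights.num_bits == 8
            for g in self.config_groups.values()
        )


fp8_cfg = SketchConfig({"group_0": ConfigGroup(WeightArgs("float", 8))})
int8_cfg = SketchConfig({"group_0": ConfigGroup(WeightArgs("int", 8))})
```

With the check living on the config, the quantizer no longer needs a free-standing helper and can read `config.is_fp8` directly.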
    if self.is_fp8:
        return False

This also depends on whether we dequantize or not.
    class CompressedTensorsActivationScaleConvert(ConversionOps):
        """Rename compressed-tensors `input_scale` to `activation_scale`."""

        def convert(self, input_dict, **kwargs):
            scale = input_dict["input_scale"][0]
            return {"activation_scale": scale.to(torch.float32)}

        @property
        def reverse_op(self):
            return _IdentityOp()

We can just keep the same name, no?
    class CompressedTensorsScaleConvert(ConversionOps):
        """Convert compressed-tensors `weight_scale` to `weight_scale_inv`.

        In compressed-tensors, `weight_scale` is the dequantization multiplier:
            bf16_weight = fp8_weight * weight_scale

        In our CompressedTensorsFP8Linear, `weight_scale_inv` has the same semantics (it's
        multiplied with the FP8 weight to get the dequantized value), so no inversion is needed.
        The conversion also reshapes the scale: scalar → (1, 1), 1D (N,) → (N, 1).
        """

Same here, we can keep the same name, no?
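The docstring above describes two things: the scale reshape (scalar → (1, 1), (N,) → (N, 1)) and the dequantization rule `bf16_weight = fp8_weight * weight_scale`. A small sketch on plain Python lists, assuming nothing beyond that docstring; `target_scale_shape` and `dequantize` are hypothetical helper names:

```python
def target_scale_shape(shape):
    """Map a checkpoint scale shape to the 2D shape described in the docstring."""
    if len(shape) == 0:        # scalar: per-tensor scale
        return (1, 1)
    if len(shape) == 1:        # (N,): per-channel scale
        return (shape[0], 1)
    return tuple(shape)        # already 2D: leave unchanged


def dequantize(fp8_weight, scale_column):
    """Apply weight * scale, broadcasting an (N, 1) or (1, 1) scale over rows.

    fp8_weight:   list of N rows (values already in the FP8 range).
    scale_column: list of N one-element rows, or a single one-element row.
    """
    n_rows = len(fp8_weight)
    per_channel = len(scale_column) == n_rows
    return [
        [v * scale_column[i if per_channel else 0][0] for v in row]
        for i, row in enumerate(fp8_weight)
    ]


w = [[1.0, 2.0], [3.0, 4.0]]
per_channel = dequantize(w, [[0.5], [2.0]])   # one scale per output row
per_tensor = dequantize(w, [[2.0]])           # one scale for the whole matrix
```

Because the scale is used as a multiplier in both conventions, renaming it to `weight_scale_inv` does not change any value, which is the reviewer's argument for keeping the original name.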
    class CompressedTensorsFp8Dequantize(ConversionOps):
        """Dequantize compressed-tensors FP8 weights back to BF16.

        Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight,
        producing a BF16 tensor. Prepended to a converter chain for layers that
        cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``):
        it pairs each weight with its sibling scale *by index* and preserves the
        per-expert list structure so the downstream merge / concat ops still see
        one tensor per expert.
        """

Do we really need this? It is not really useful to online-quantize a model only to finally save it in BF16, no?
With compressed-tensors, a quantized model will be dequantized in any case if you specify run_compressed=False, no?
Also, can you explain the MoE bit? I didn't fully understand it.
Hi @SunMarc, thanks for the review! Addressed everything; kept the MoE dequant-before-merge on purpose (explained below).
Done
Side effect:
The MoE dequant (comments 10 & 11): MoE checkpoints store per-expert weights and scales, but transformers merges the experts (
Known limitation: since merged experts land in BF16, this PR doesn't save memory on MoE expert weights (the bulk of the parameters); only dense models and the attention/router part of MoE models benefit. This is not fundamental: experts could stay in FP8 with (1) a 3D FP8 expert parameter, (2) a stacked per-expert scale tensor, and (3) a grouped scaled-mm in the MoE forward. That's a bigger change, so I'd suggest a follow-up and keeping this dequant path as the correct baseline.
Happy to adjust naming or split differently. Let me know!
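The dequant-before-merge idea described in this reply can be sketched on plain lists: each expert's weight is paired with its own scale by index, dequantized, and only then handed to the merge step. This is an illustration under assumptions, not the PR's `CompressedTensorsFp8Dequantize`; `dequantize_experts` and `merge_experts` are hypothetical names and the scales are assumed to be per-row.

```python
def dequantize_experts(expert_weights, expert_scales):
    """Pair each expert's weight with its scale by index and dequantize it.

    expert_weights: list of 2D matrices, one per expert.
    expert_scales:  list of per-row scale lists, one per expert.
    Returns one dequantized matrix per expert, in the same order, so the
    downstream merge still sees one entry per expert.
    """
    if len(expert_weights) != len(expert_scales):
        raise ValueError("each expert weight needs exactly one sibling scale")
    return [
        [[v * s for v in row] for row, s in zip(weight, scales)]
        for weight, scales in zip(expert_weights, expert_scales)
    ]


def merge_experts(per_expert):
    """Stand-in for the downstream merge op: stack experts along a new axis."""
    return list(per_expert)


weights = [[[1.0, 2.0]], [[3.0, 4.0]]]   # two experts, each a 1x2 matrix
scales = [[0.5], [2.0]]                  # one per-row scale list per expert
merged = merge_experts(dequantize_experts(weights, scales))
```

Dequantizing before the merge matters because each expert carries its own scale; once the weights are stacked into a single parameter, a plain merge op has no place to keep those per-expert scales.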
SunMarc left a comment
Left a couple of comments! Try to fix those, but I can take over the PR at some point if needed. I really want this to be done correctly.
    run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
        emulate compressed model execution if True, otherwise use default submodule

run_compressed should be deprecated, and we should use dequantize instead:
    dequantize = not run_compressed
We shouldn't use it anywhere else.
Done. run_compressed is deprecated (warns and maps to dequantize); dequantize is now the single source of truth for all internal logic.
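A minimal sketch of the "warns and maps to dequantize" behaviour described above, using only the standard library. `resolve_dequantize` is a hypothetical helper, and the choices that an explicit `dequantize` wins over `run_compressed` and that the default is `False` (mirroring `run_compressed=True`) are assumptions, not confirmed details of the PR:

```python
import warnings


def resolve_dequantize(dequantize=None, run_compressed=None):
    """Return the effective `dequantize` flag, honoring deprecated `run_compressed`."""
    if run_compressed is not None:
        warnings.warn(
            "`run_compressed` is deprecated; use `dequantize` instead.",
            FutureWarning,
            stacklevel=2,
        )
        # Assumption: an explicitly passed `dequantize` takes precedence.
        if dequantize is None:
            dequantize = not run_compressed
    # Assumed default: stay compressed, matching the old run_compressed=True.
    return bool(dequantize) if dequantize is not None else False
```

Resolving the flag once, at config construction time, is what lets the rest of the code read only `dequantize`.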
    def get_quantize_ops(self):
        # Online FP8 quantization is not supported (use finegrained-fp8 instead),
        # so there is never an on-the-fly quantization op to apply.
        return None

    from ..integrations.compressed_tensors_fp8 import CompressedTensorsFp8Dequantize

    updated: list = []
    for conv in weight_conversions:
        # Only WeightConverter has ``.operations`` to extend; WeightRenaming
        # rules just pass through untouched.
        if not isinstance(conv, WeightConverter):
            updated.append(conv)
            continue

        weight_sources = [p for p in conv.source_patterns if p.endswith(".weight")]
        if weight_sources:
            anchored_weight = [p + "$" for p in weight_sources]
            scale_sources = [p[: -len(".weight")] + ".weight_scale$" for p in weight_sources]
            other = [p for p in conv.source_patterns if not p.endswith(".weight")]
            new_sources = anchored_weight + scale_sources + other
            new_ops = [CompressedTensorsFp8Dequantize()] + list(conv.operations)
            conv = WeightConverter(
                source_patterns=new_sources,
                target_patterns=conv._original_target_patterns,
                operations=new_ops,
            )
        updated.append(conv)

This means that when we do fusion, we will have to dequantize this particular module? If that is the case, maybe we should emit a warning, no?
Done, added a logger.warning_once when merged (e.g. MoE expert) weights fall back to dequantized BF16.
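For readers who have not seen the warn-once pattern, here is a standard-library sketch of the idea: cache on the message so a repeated fallback is logged only the first time. The names `warning_once` and `note_bf16_fallback`, the logger name and the message text are all hypothetical; this is not the transformers logger itself.

```python
import functools
import logging

logger = logging.getLogger("fp8_sketch")


@functools.lru_cache(maxsize=None)
def warning_once(message):
    """Log `message` at WARNING level the first time it is seen, then stay silent."""
    logger.warning(message)


def note_bf16_fallback(pattern):
    # Hypothetical helper: called whenever a merged weight is dequantized.
    warning_once(
        f"Weights matching '{pattern}' are merged at load time and fall back "
        "to dequantized BF16."
    )
```

Since the converter loop runs once per weight conversion, deduplicating on the message keeps a model with many expert layers from flooding the log.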
    def is_trainable(self):
        # The FP8 kernel path is inference-only. With `dequantize=True` we don't enter
        # it (`is_fp8` is False) and fall back to the regular compressed-tensors route.
        if self.is_fp8:

Can we think of a better name than is_fp8?
Done. The quantizer flag is now use_fp8_kernel (is_fp8 on the config still means "the checkpoint is FP8 type").
    if quantization_config.is_fp8 and not quantization_config.dequantize:
        self.is_fp8 = True

I think it might be better to put the quantization_config.dequantize check in the is_fp8 method directly, no?
Moved the dequantize check off the quantizer and into the config, as you asked. I kept is_fp8 meaning purely "the checkpoint is FP8 type" (so it doesn't flip to False for an FP8 checkpoint that's being dequantized) and added a dedicated property CompressedTensorsConfig.use_fp8_kernel = is_fp8 and not dequantize. The quantizer now just reads use_fp8_kernel. This also covers your "better name than is_fp8" comment. Happy to fold it straight into is_fp8 instead if you prefer.
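The split described in this reply can be summarized in a few lines. This is a sketch, assuming `is_fp8` is available as a plain boolean; `SketchConfig` is a hypothetical stand-in for `CompressedTensorsConfig`:

```python
from dataclasses import dataclass


@dataclass
class SketchConfig:  # hypothetical stand-in for CompressedTensorsConfig
    is_fp8: bool               # the checkpoint is FP8 type
    dequantize: bool = False   # the user asked to dequantize at load time

    @property
    def use_fp8_kernel(self) -> bool:
        # Take the kernel path only for FP8 checkpoints that stay compressed.
        return self.is_fp8 and not self.dequantize
```

Keeping the two flags separate means `is_fp8` still answers "what is in the checkpoint?" even when the model is being dequantized, while `use_fp8_kernel` answers "which code path runs?".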
    def _scale_pattern_for(weight_pattern: str) -> str:
        # Strip the optional ``$`` regex anchor so we can match the underlying name.
        anchored = weight_pattern.endswith("$")
        base = weight_pattern[:-1] if anchored else weight_pattern
        if base.endswith(".weight"):
            scale = base[: -len(".weight")] + ".weight_scale"
        elif base == "weight":
            scale = "weight_scale"
        else:
            scale = base + "_scale"
        return scale + "$" if anchored else scale

Do we really need all these cases?
Done, reduced to the single .weight -> .weight_scale case.
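A guess at what the reduced helper might look like once only the `.weight` → `.weight_scale` case remains. The thread does not show the final code, so `scale_pattern_for` below is a hypothetical reconstruction, and raising on any other input is an assumption:

```python
def scale_pattern_for(weight_pattern):
    """Map a `.weight` source pattern to its sibling `.weight_scale` pattern."""
    suffix = ".weight"
    if not weight_pattern.endswith(suffix):
        # Assumption: anything else is a caller error rather than a silent guess.
        raise ValueError(
            f"expected a pattern ending in '{suffix}', got {weight_pattern!r}"
        )
    return weight_pattern[: -len(suffix)] + ".weight_scale"
```

For example, `scale_pattern_for("mlp.experts.*.gate_proj.weight")` returns `"mlp.experts.*.gate_proj.weight_scale"`.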
    class CompressedTensorsFp8Dequantize(ConversionOps):
        """Dequantize compressed-tensors FP8 weights back to BF16.

        Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight,
        producing a BF16 tensor. Prepended to a converter chain for layers that
        cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``):
        it pairs each weight with its sibling scale *by index* and preserves the
        per-expert list structure so the downstream merge / concat ops still see
        one tensor per expert.
        """

In general, can you try to simplify this a bit?
    if not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())):
        self.skipTest("FP8 kernel path requires GPU or XPU")

Done, replaced the manual skipTest with @require_torch_accelerator.
    # check perplexity
    perplexity = torch.exp(outputs.loss)
    self.assertLessEqual(perplexity, expected_perplexity)

Please add more integration tests. Test the dequantize path; with your changes, some FP8 tests may also no longer behave the same. Also add a test for saving the FP8 model: it should work again after reloading. It feels like the changes we made are quite minor.
Done. Added test_tinyllama_fp8_dequantize (CPU dequantize path) and test_tinyllama_fp8_save_reload.
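The perplexity check in the test snippet above relies on the loss being the mean per-token negative log-likelihood, so that perplexity is its exponential. A torch-free sketch of that relationship, with a hypothetical `perplexity` helper and made-up per-token values:

```python
import math


def perplexity(token_nlls):
    """Perplexity is exp of the mean per-token negative log-likelihood."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that assigns probability 1/4 to every target token has
# NLL = log(4) per token, hence a perplexity of 4 (up to rounding).
uniform_over_four = perplexity([math.log(4.0)] * 3)
```

A threshold test such as `assertLessEqual(perplexity, expected_perplexity)` therefore bounds the average loss, which makes it a coarse but cheap guard against a kernel path that silently degrades outputs.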
[For maintainers] Suggested jobs to run (before merge): run-slow: compressed_tensors_integration
Hi @SunMarc. I've addressed all your comments. Please take a look. Thanks!
What does this PR do?

This PR adds native FP8 matmul kernel support for compressed-tensors FP8 quantized models in transformers. Previously, compressed-tensors FP8 models were loaded via the compressed-tensors library and dequantized back to FP16/BF16 for inference. With this change, FP8 weights are kept in FP8 format and inference uses hardware-accelerated FP8 matmul kernels (torch._scaled_mm on XPU, fbgemm.f8f8bf16_rowwise on CUDA).

Key changes:

New file: src/transformers/integrations/compressed_tensors_fp8.py
- CTFP8Linear: FP8 linear layer that stores weights in FP8 and uses row-wise FP8 matmul kernels. Activations are dynamically quantized per-row via quantize_fp8_per_row.
- Conversion ops (CompressedTensorsScaleConvert, CompressedTensorsFp8Dequantize) to handle the checkpoint format conversion (e.g. weight_scale → weight_scale_inv).
- CTFP8PerRowQuantize: online quantization support; quantizes BF16 weights to FP8 per-row on the fly during model loading.

Modified: src/transformers/quantizers/quantizer_compressed_tensors.py
- CompressedTensorsHfQuantizer now detects FP8 quantization configs (float type, num_bits=8) and automatically routes to the FP8 kernel path when GPU/XPU is available. Falls back to the default compressed-tensors dequantize path on CPU.
- get_weight_conversions() and get_quantize_ops() to support both pre-quantized loading and online quantization.

Modified: src/transformers/quantizers/auto.py

Supported models
- CompressedTensorsConfig with FP8 quantization scheme.

Usage
Pre-quantized model (no config needed)
Online quantization

Devices
- XPU: torch._scaled_mm
- CUDA: fbgemm.f8f8bf16_rowwise

@sywangyi