diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index ec2b1f4033..1fddecd6ae 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -134,11 +134,11 @@ def determine_available_memory(self) -> int: with disable_compilation(model): return super().determine_available_memory() - def compile_or_warm_up_model(self) -> None: + def compile_or_warm_up_model(self) -> float: if ( quant_config["quant_cfg"] or quant_config["kv_quant_cfg"] or quant_config["modelopt_state_path"] ): _fakequant_run_prolog_worker(self) - super().compile_or_warm_up_model() + return super().compile_or_warm_up_model() diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 1908354a0a..2d6717ecbb 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -14,6 +14,8 @@ # limitations under the License. """Export HuggingFace model to vLLM fakequant checkpoint.""" +import logging +from collections.abc import Mapping from pathlib import Path import torch @@ -26,6 +28,8 @@ from modelopt.torch.quantization.utils import get_quantizer_state_dict from modelopt.torch.utils import get_unwrapped_name +logger = logging.getLogger(__name__) + __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -38,6 +42,135 @@ def disable_rotate(quantizer: TensorQuantizer): return False +def _materialize_offloaded_weights( + model: nn.Module, + state_dict: dict[str, torch.Tensor], + meta_keys: list[str], +) -> None: + """Replace meta tensors in state_dict with actual data from accelerate offload hooks. + + When a model is loaded with ``device_map="auto"`` and some layers are offloaded + to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those + layers. This function walks the model's accelerate hooks to retrieve the actual + weight data and updates state_dict in-place. + """ + hook_entries: list[tuple[str, str, Mapping[str, torch.Tensor]]] = [] + + def _weights_map_from_hook(hook_obj): + """Best-effort extraction of an accelerate weights_map from a hook object.""" + if hasattr(hook_obj, "weights_map") and hook_obj.weights_map is not None: + return hook_obj.weights_map + if hasattr(hook_obj, "hooks"): + for h in hook_obj.hooks: + if hasattr(h, "weights_map") and h.weights_map is not None: + return h.weights_map + return None + + try: + # Reuse accelerate plugin hook resolution instead of duplicating traversal logic. + from modelopt.torch.quantization.plugins.accelerate import _get_cpu_offload_hook + except ImportError: + _get_cpu_offload_hook = None + + for name, module in model.named_modules(): + hook = getattr(module, "_hf_hook", None) + if hook is None or _get_cpu_offload_hook is None: + continue + + align_hook = None + if _get_cpu_offload_hook is not None: + try: + align_hook = _get_cpu_offload_hook(hook) + except AssertionError: + # Some accelerate hook variants do not expose a plain "weight" key. + # Fall back to generic weights_map extraction for export-time readout. + align_hook = None + + wmap = align_hook.weights_map if align_hook is not None else _weights_map_from_hook(hook) + if wmap is None: + continue + + if hasattr(wmap, "dataset"): + weight_prefix = wmap.prefix + actual_sd = wmap.dataset.state_dict + else: + weight_prefix = "" + actual_sd = wmap + + module_prefix = f"{name}." if name else "" + hook_entries.append((module_prefix, weight_prefix, actual_sd)) + + # Match most-specific module prefixes first to avoid ambiguous parent-prefix hits. + hook_entries.sort(key=lambda x: len(x[0]), reverse=True) + + materialized = 0 + for key in meta_keys: + for module_prefix, weight_prefix, actual_sd in hook_entries: + if not key.startswith(module_prefix): + continue + local_key = key[len(module_prefix) :] + lookup_key = weight_prefix + local_key + if lookup_key in actual_sd: + state_dict[key] = actual_sd[lookup_key].detach().clone() + materialized += 1 + break + else: + logger.warning("Could not materialize meta tensor for key: %s", key) + + logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys)) + + +def _save_clean_checkpoint( + model: nn.Module, + clean_sd: dict[str, torch.Tensor], + export_dir: Path, +) -> None: + """Save clean weights + config directly, bypassing model.save_pretrained(). + + For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)`` + ignores the provided state_dict and saves from internal state, leaking + quantizer keys. This function saves ``clean_sd`` directly via safetensors + API, guaranteeing only the intended keys are written. + """ + import json + + from huggingface_hub import split_torch_state_dict_into_shards + from safetensors.torch import save_file + + export_dir.mkdir(parents=True, exist_ok=True) + + state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB") + for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): + # Keep peak memory bounded: move and clone one shard at a time. + # Cloning also breaks shared storage, which safetensors rejects. + shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys} + save_file(shard, str(export_dir / shard_file)) + logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) + + if hasattr(model, "config"): + config = model.config.to_dict() + config_path = export_dir / "config.json" + config_path.write_text(json.dumps(config, indent=2) + "\n") + logger.info("Saved config.json") + + generation_config = getattr(model, "generation_config", None) + if generation_config is not None: + generation_config.save_pretrained(export_dir) + + logger.info( + "Checkpoint saved: %d weights in %d shard(s)", + len(clean_sd), + len(state_dict_split.filename_to_tensors), + ) + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -62,6 +195,31 @@ def export_hf_vllm_fq_checkpoint( # parameters are never modified. Apply each weight quantizer's fake-quant # to the corresponding weight tensor in the copy. state_dict = model.state_dict() + + # Handle accelerate-offloaded models: state_dict() returns meta tensors + # for CPU/disk-offloaded layers. Materialize them from the offload hooks. + meta_keys = [k for k, v in state_dict.items() if v.is_meta] + has_offloaded_weights = bool(meta_keys) + if meta_keys: + logger.info( + "Found %d meta tensors in state_dict (accelerate offloading). " + "Materializing from offload hooks...", + len(meta_keys), + ) + _materialize_offloaded_weights(model, state_dict, meta_keys) + unresolved_meta_keys = [ + k + for k, v in state_dict.items() + if v.is_meta and "quantizer" not in k and "quant" not in k + ] + if unresolved_meta_keys: + shown = ", ".join(unresolved_meta_keys[:10]) + suffix = " ..." if len(unresolved_meta_keys) > 10 else "" + raise RuntimeError( + "Failed to materialize offloaded tensors before fake-quant folding / " + f"_save_clean_checkpoint: {shown}{suffix}" + ) + fakequant_weights = set() input_quantizers_folded_pqs = ( set() @@ -86,6 +244,26 @@ def export_hf_vllm_fq_checkpoint( ) if sd_key in state_dict: w = state_dict[sd_key] + # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. + # Offloaded weights materialized to CPU need a GPU hop. + if not w.is_cuda: + # Find a CUDA device from quantizer/module tensors. + cuda_dev = None + for t in list(quantizer.parameters()) + list(quantizer.buffers()): + if t.is_cuda: + cuda_dev = t.device + break + if cuda_dev is None: + for t in list(module.parameters()) + list(module.buffers()): + if t.is_cuda: + cuda_dev = t.device + break + if cuda_dev is None: + raise RuntimeError( + "Cannot find CUDA device for quantizer kernel on offloaded weight " + f"'{sd_key}'. Ensure at least one quantizer/module tensor is on CUDA." + ) + w = w.to(cuda_dev) w_quant = quantizer(w.float()).to(w.dtype).cpu() # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) # Only valid when input_quantizer does NOT fake-quant activations. If it does @@ -117,53 +295,59 @@ def export_hf_vllm_fq_checkpoint( # Rotation is also cleared: the weight was already folded with rotation applied, # so if fold_weight is called on reload it must not re-rotate the exported weight. wqs_to_restore = [] - for _, module in model.named_modules(): - if isinstance(module, QuantModule): - for attr_name, quantizer in module.named_children(): - if ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.is_enabled - ): - quantizer.disable() - orig_rotate = quantizer._rotate - if quantizer.rotate_is_enabled: - quantizer._rotate = disable_rotate(quantizer) - wqs_to_restore.append((quantizer, orig_rotate)) - - quantizer_state_dict = get_quantizer_state_dict(model) - for key in list(quantizer_state_dict): - if key.endswith("weight_quantizer"): - # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. - quantizer_state_dict.pop(key) - elif key in input_quantizers_folded_pqs: - # pre_quant_scale was folded into the weight; keep the buffer for strict load but - # save identity so activations are not scaled twice. - qstate_val = quantizer_state_dict[key] - if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: - quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( - qstate_val["_pre_quant_scale"] - ) - modelopt_state = mto.modelopt_state(model) - # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild - # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). - qstate = quantizer_state(model) - for key in list(qstate): - if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): - qstate.pop(key) - - for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): - if mode_str == "quantize" and "metadata" in m_state: - m_state["metadata"]["quantizer_state"] = qstate - break - - # Per-quantizer tensor dict loaded alongside metadata on reload. - modelopt_state["modelopt_state_weights"] = quantizer_state_dict - torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") - - # Step 3: Save HF weights using the pre-built folded state dict. - model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) - - for wq, orig_rotate in wqs_to_restore: - wq.enable() - wq._rotate = orig_rotate + try: + for _, module in model.named_modules(): + if isinstance(module, QuantModule): + for attr_name, quantizer in module.named_children(): + if ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.is_enabled + ): + quantizer.disable() + orig_rotate = quantizer._rotate + if quantizer.rotate_is_enabled: + quantizer._rotate = disable_rotate(quantizer) + wqs_to_restore.append((quantizer, orig_rotate)) + + quantizer_state_dict = get_quantizer_state_dict(model) + for key in list(quantizer_state_dict): + if key.endswith("weight_quantizer"): + # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. + quantizer_state_dict.pop(key) + elif key in input_quantizers_folded_pqs: + # pre_quant_scale was folded into the weight; keep the buffer for strict load but + # save identity so activations are not scaled twice. + qstate_val = quantizer_state_dict[key] + if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: + quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( + qstate_val["_pre_quant_scale"] + ) + modelopt_state = mto.modelopt_state(model) + # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild + # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). + qstate = quantizer_state(model) + for key in list(qstate): + if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): + qstate.pop(key) + + for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): + if mode_str == "quantize" and "metadata" in m_state: + m_state["metadata"]["quantizer_state"] = qstate + break + + # Per-quantizer tensor dict loaded alongside metadata on reload. + modelopt_state["modelopt_state_weights"] = quantizer_state_dict + torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") + + # Step 3: Save HF weights. + # Accelerate-offloaded models may ignore state_dict= in save_pretrained() + # and leak quantizer keys, so use manual save only in that case. + if has_offloaded_weights: + _save_clean_checkpoint(model, clean_sd, export_dir) + else: + model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) + finally: + for wq, orig_rotate in wqs_to_restore: + wq.enable() + wq._rotate = orig_rotate diff --git a/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py b/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py new file mode 100644 index 0000000000..2e2604e412 --- /dev/null +++ b/tests/unit/torch/export/test_vllm_fakequant_hf_export_utils.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import types + +import pytest +import torch +import torch.nn as nn +from modelopt.torch.export.plugins import vllm_fakequant_hf as vllm_fq + + +class _MinimalModel(nn.Module): + def __init__(self, meta_weight: bool = False): + super().__init__() + device = "meta" if meta_weight else "cpu" + self.weight = nn.Parameter(torch.ones(4, device=device)) + self.save_calls = [] + + def save_pretrained(self, export_dir, **kwargs): + self.save_calls.append((export_dir, kwargs)) + + +class _DummyHook: + def __init__(self, weights_map): + self.weights_map = weights_map + + +def _patch_minimal_modelopt_state(monkeypatch): + monkeypatch.setattr(vllm_fq, "get_quantizer_state_dict", lambda _model: {}) + monkeypatch.setattr(vllm_fq, "quantizer_state", lambda _model: {}) + monkeypatch.setattr(vllm_fq.mto, "modelopt_state", lambda _model: {"modelopt_state_dict": []}) + monkeypatch.setattr(vllm_fq.torch, "save", lambda _obj, _path: None) + + +def test_materialize_uses_longest_module_prefix(monkeypatch): + class _NestedModel(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Module() + self.a.b = nn.Module() + self.a._hf_hook = _DummyHook({"b.weight": torch.tensor([1.0])}) + self.a.b._hf_hook = _DummyHook({"weight": torch.tensor([2.0])}) + + model = _NestedModel() + state_dict = {"a.b.weight": torch.empty(1, device="meta")} + + fake_accel = types.ModuleType("modelopt.torch.quantization.plugins.accelerate") + fake_accel._get_cpu_offload_hook = lambda hook: hook + monkeypatch.setitem(sys.modules, "modelopt.torch.quantization.plugins.accelerate", fake_accel) + + vllm_fq._materialize_offloaded_weights(model, state_dict, ["a.b.weight"]) + assert torch.allclose(state_dict["a.b.weight"], torch.tensor([2.0])) + + +def test_export_raises_if_non_quant_meta_tensors_remain(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=True) + + monkeypatch.setattr(vllm_fq, "_materialize_offloaded_weights", lambda *_args, **_kwargs: None) + + with ( + torch.inference_mode(), + pytest.raises(RuntimeError, match="Failed to materialize offloaded tensors") as exc, + ): + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_meta_fail") + assert "_save_clean_checkpoint" in str(exc.value) + + +def test_export_uses_model_save_pretrained_when_not_offloaded(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=False) + called = {"clean": 0} + + def _save_clean_checkpoint(*_args, **_kwargs): + called["clean"] += 1 + + monkeypatch.setattr(vllm_fq, "_save_clean_checkpoint", _save_clean_checkpoint) + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_non_offloaded") + + assert called["clean"] == 0 + assert len(model.save_calls) == 1 + assert model.save_calls[0][1]["save_modelopt_state"] is False + assert "state_dict" in model.save_calls[0][1] + + +def test_export_uses_clean_checkpoint_when_offloaded(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + model = _MinimalModel(meta_weight=True) + called = {"clean": 0} + + def _materialize(_model, state_dict, _meta_keys): + state_dict["weight"] = torch.ones(4) + + def _save_clean_checkpoint(*_args, **_kwargs): + called["clean"] += 1 + + def _unexpected_save_pretrained(*_args, **_kwargs): + raise AssertionError("model.save_pretrained should not be called for offloaded export") + + monkeypatch.setattr(vllm_fq, "_materialize_offloaded_weights", _materialize) + monkeypatch.setattr(vllm_fq, "_save_clean_checkpoint", _save_clean_checkpoint) + model.save_pretrained = _unexpected_save_pretrained + + vllm_fq.export_hf_vllm_fq_checkpoint(model, export_dir=tmp_path / "export_offloaded") + assert called["clean"] == 1 + + +def test_export_raises_when_cuda_device_cannot_be_found(monkeypatch, tmp_path): + _patch_minimal_modelopt_state(monkeypatch) + + class _DummyTensorQuantizer(nn.Module): + def __init__(self): + super().__init__() + self.fake_quant = True + self.is_enabled = True + self.rotate_is_enabled = False + self._rotate = False + + def disable(self): + self.is_enabled = False + + def enable(self): + self.is_enabled = True + + def forward(self, x): + return x + + class _DummyQuantModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(2, 2)) + self.weight_quantizer = _DummyTensorQuantizer() + + class _DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.block = _DummyQuantModule() + + def save_pretrained(self, _export_dir, **_kwargs): + return None + + monkeypatch.setattr(vllm_fq, "QuantModule", _DummyQuantModule) + monkeypatch.setattr(vllm_fq, "TensorQuantizer", _DummyTensorQuantizer) + + with torch.inference_mode(), pytest.raises( + RuntimeError, match="Cannot find CUDA device for quantizer kernel" + ): + vllm_fq.export_hf_vllm_fq_checkpoint( + _DummyModel(), export_dir=tmp_path / "export_cuda_missing" + )