Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def determine_available_memory(self) -> int:
with disable_compilation(model):
return super().determine_available_memory()

def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
if (
quant_config["quant_cfg"]
or quant_config["kv_quant_cfg"]
or quant_config["modelopt_state_path"]
):
_fakequant_run_prolog_worker(self)
super().compile_or_warm_up_model()
return super().compile_or_warm_up_model()
284 changes: 234 additions & 50 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
from collections.abc import Mapping
from pathlib import Path

import torch
Expand All @@ -26,6 +28,8 @@
from modelopt.torch.quantization.utils import get_quantizer_state_dict
from modelopt.torch.utils import get_unwrapped_name

logger = logging.getLogger(__name__)

__all__ = ["export_hf_vllm_fq_checkpoint"]


Expand All @@ -38,6 +42,135 @@ def disable_rotate(quantizer: TensorQuantizer):
return False


def _materialize_offloaded_weights(
model: nn.Module,
state_dict: dict[str, torch.Tensor],
meta_keys: list[str],
) -> None:
"""Replace meta tensors in state_dict with actual data from accelerate offload hooks.

When a model is loaded with ``device_map="auto"`` and some layers are offloaded
to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those
layers. This function walks the model's accelerate hooks to retrieve the actual
weight data and updates state_dict in-place.
"""
hook_entries: list[tuple[str, str, Mapping[str, torch.Tensor]]] = []

def _weights_map_from_hook(hook_obj):
    """Best-effort extraction of an accelerate weights_map from a hook object.

    The lookup is two-staged:

    1. If ``hook_obj`` itself exposes a non-None ``weights_map``, return it.
    2. Otherwise, if ``hook_obj`` exposes a ``hooks`` collection, return the
       first non-None ``weights_map`` found on its members, in order.

    Args:
        hook_obj: Any hook-like object; no particular type is required.

    Returns:
        The located ``weights_map``, or ``None`` when neither the hook nor
        any of its members carries one.
    """
    direct_map = getattr(hook_obj, "weights_map", None)
    if direct_map is not None:
        return direct_map

    # A missing *or* None ``hooks`` attribute means there is nothing to scan;
    # this helper is best-effort and must not raise on that shape.
    sub_hooks = getattr(hook_obj, "hooks", None)
    if sub_hooks is not None:
        for sub_hook in sub_hooks:
            nested_map = getattr(sub_hook, "weights_map", None)
            if nested_map is not None:
                return nested_map
    return None

try:
# Reuse accelerate plugin hook resolution instead of duplicating traversal logic.
from modelopt.torch.quantization.plugins.accelerate import _get_cpu_offload_hook
except ImportError:
_get_cpu_offload_hook = None

for name, module in model.named_modules():
hook = getattr(module, "_hf_hook", None)
if hook is None or _get_cpu_offload_hook is None:
continue
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: When `_get_cpu_offload_hook` is `None` (ImportError), the function will skip all modules with `_hf_hook`, meaning no weights get materialized. This will then trigger the `RuntimeError` downstream for any non-quantizer meta keys, which is fine. But a more informative early error or warning when the accelerate plugin import fails would help debugging:

if _get_cpu_offload_hook is None and meta_keys:
    logger.warning(
        "Could not import accelerate plugin (_get_cpu_offload_hook). "
        "Cannot materialize %d offloaded weights.", len(meta_keys)
    )
    return


align_hook = None
if _get_cpu_offload_hook is not None:
try:
align_hook = _get_cpu_offload_hook(hook)
except AssertionError:
# Some accelerate hook variants do not expose a plain "weight" key.
# Fall back to generic weights_map extraction for export-time readout.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The guard `if _get_cpu_offload_hook is None` is checked at the top of the loop with `continue`, so the inner `if _get_cpu_offload_hook is not None:` on line 88 is always True and can be removed.

# This is dead code — _get_cpu_offload_hook is guaranteed non-None here
if _get_cpu_offload_hook is not None:

align_hook = None

wmap = align_hook.weights_map if align_hook is not None else _weights_map_from_hook(hook)
if wmap is None:
continue

if hasattr(wmap, "dataset"):
weight_prefix = wmap.prefix
actual_sd = wmap.dataset.state_dict
else:
weight_prefix = ""
actual_sd = wmap

module_prefix = f"{name}." if name else ""
hook_entries.append((module_prefix, weight_prefix, actual_sd))

# Match most-specific module prefixes first to avoid ambiguous parent-prefix hits.
hook_entries.sort(key=lambda x: len(x[0]), reverse=True)

materialized = 0
for key in meta_keys:
for module_prefix, weight_prefix, actual_sd in hook_entries:
if not key.startswith(module_prefix):
continue
local_key = key[len(module_prefix) :]
lookup_key = weight_prefix + local_key
if lookup_key in actual_sd:
state_dict[key] = actual_sd[lookup_key].detach().clone()
materialized += 1
break
else:
logger.warning("Could not materialize meta tensor for key: %s", key)

logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys))


def _save_clean_checkpoint(
model: nn.Module,
clean_sd: dict[str, torch.Tensor],
export_dir: Path,
) -> None:
"""Save clean weights + config directly, bypassing model.save_pretrained().

For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)``
ignores the provided state_dict and saves from internal state, leaking
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behavioral change for ALL models: _save_clean_checkpoint now replaces model.save_pretrained() unconditionally — not just for offloaded models. The old save_pretrained() also saved generation_config.json, tokenizer files (if applicable), and ran any save hooks. The new code only saves safetensors + config.json.

For the vLLM FakeQuant use case this is probably fine (vLLM doesn't need generation_config.json from the export dir). But it's worth documenting this behavioral change, or alternatively only using _save_clean_checkpoint when offloading is detected (i.e., when meta_keys is non-empty) and falling back to model.save_pretrained() otherwise.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_save_clean_checkpoint is now used only when offloaded/meta tensors are detected.

quantizer keys. This function saves ``clean_sd`` directly via safetensors
API, guaranteeing only the intended keys are written.
"""
import json

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

export_dir.mkdir(parents=True, exist_ok=True)

state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
# Keep peak memory bounded: move and clone one shard at a time.
# Cloning also breaks shared storage, which safetensors rejects.
shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
save_file(shard, str(export_dir / shard_file))
logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding "model_type" to the metadata in the index JSON. The HuggingFace convention for model.safetensors.index.json typically includes {"total_size": ...} in metadata. state_dict_split.metadata may already have this, but worth verifying — some tools expect total_size in the index metadata.


if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
(export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))

if hasattr(model, "config"):
config = model.config.to_dict()
config_path = export_dir / "config.json"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: auto_map is not stripped from config.json. The PR description explicitly lists 'strips auto_map from config.json' as one of the four fixes, noting that auto_map references custom Python files not present in the export directory, causing OSError in vLLM. However, model.config.to_dict() will preserve auto_map if the model config has it set. You need to explicitly remove it:

config = model.config.to_dict()
config.pop("auto_map", None)  # Custom code files are not in the export dir

config_path.write_text(json.dumps(config, indent=2) + "\n")
logger.info("Saved config.json")

generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
generation_config.save_pretrained(export_dir)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing tokenizer save. model.save_pretrained() typically also saves the tokenizer (if the model has one). This bypass only saves weights, config, and generation_config. Consider whether the tokenizer should be saved here as well, or document why it's intentionally omitted. If the user expects a complete self-contained checkpoint, the missing tokenizer could be a problem.

Also, some models have additional files saved by save_pretrained (e.g., preprocessor_config.json for multimodal models). This manual approach may miss them.

logger.info(
"Checkpoint saved: %d weights in %d shard(s)",
len(clean_sd),
len(state_dict_split.filename_to_tensors),
)


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
Expand All @@ -62,6 +195,31 @@ def export_hf_vllm_fq_checkpoint(
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()

# Handle accelerate-offloaded models: state_dict() returns meta tensors
# for CPU/disk-offloaded layers. Materialize them from the offload hooks.
meta_keys = [k for k, v in state_dict.items() if v.is_meta]
has_offloaded_weights = bool(meta_keys)
if meta_keys:
logger.info(
"Found %d meta tensors in state_dict (accelerate offloading). "
"Materializing from offload hooks...",
len(meta_keys),
)
_materialize_offloaded_weights(model, state_dict, meta_keys)
unresolved_meta_keys = [
k
for k, v in state_dict.items()
if v.is_meta and "quantizer" not in k and "quant" not in k
]
if unresolved_meta_keys:
shown = ", ".join(unresolved_meta_keys[:10])
suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
raise RuntimeError(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unresolved meta key check filters out keys containing "quantizer" or "quant". This heuristic could accidentally suppress real errors — a weight key like "dequant_proj.weight" or "quantize_proj.weight" (if such names ever exist) would be silently ignored. Consider being more precise, e.g., checking if the key corresponds to a TensorQuantizer module rather than using substring matching.

"Failed to materialize offloaded tensors before fake-quant folding / "
f"_save_clean_checkpoint: {shown}{suffix}"
)

fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
Expand All @@ -86,6 +244,26 @@ def export_hf_vllm_fq_checkpoint(
)
if sd_key in state_dict:
w = state_dict[sd_key]
# Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA.
# Offloaded weights materialized to CPU need a GPU hop.
if not w.is_cuda:
# Find a CUDA device from quantizer/module tensors.
cuda_dev = None
for t in list(quantizer.parameters()) + list(quantizer.buffers()):
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential silent failure: If no CUDA device is found (e.g., all quantizer buffers/params and module params happen to be on CPU/meta), cuda_dev remains None and w stays on CPU. The subsequent quantizer(w.float()) call will likely fail with a cryptic CUDA error deep in the kernel. Consider raising a clear error:

if cuda_dev is None:
    raise RuntimeError(
        f"Cannot find CUDA device for quantizer kernel on offloaded weight '{sd_key}'. "
        "Ensure at least one quantizer buffer or module parameter is on CUDA."
    )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added explicit error

for t in list(module.parameters()) + list(module.buffers()):
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is None:
raise RuntimeError(
"Cannot find CUDA device for quantizer kernel on offloaded weight "
f"'{sd_key}'. Ensure at least one quantizer/module tensor is on CUDA."
)
w = w.to(cuda_dev)
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
Expand Down Expand Up @@ -117,53 +295,59 @@ def export_hf_vllm_fq_checkpoint(
# Rotation is also cleared: the weight was already folded with rotation applied,
# so if fold_weight is called on reload it must not re-rotate the exported weight.
wqs_to_restore = []
for _, module in model.named_modules():
if isinstance(module, QuantModule):
for attr_name, quantizer in module.named_children():
if (
attr_name.endswith("weight_quantizer")
and isinstance(quantizer, TensorQuantizer)
and quantizer.is_enabled
):
quantizer.disable()
orig_rotate = quantizer._rotate
if quantizer.rotate_is_enabled:
quantizer._rotate = disable_rotate(quantizer)
wqs_to_restore.append((quantizer, orig_rotate))

quantizer_state_dict = get_quantizer_state_dict(model)
for key in list(quantizer_state_dict):
if key.endswith("weight_quantizer"):
# Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
quantizer_state_dict.pop(key)
elif key in input_quantizers_folded_pqs:
# pre_quant_scale was folded into the weight; keep the buffer for strict load but
# save identity so activations are not scaled twice.
qstate_val = quantizer_state_dict[key]
if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
qstate_val["_pre_quant_scale"]
)
modelopt_state = mto.modelopt_state(model)
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
# ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
qstate = quantizer_state(model)
for key in list(qstate):
if key.endswith("weight_quantizer") and qstate[key].get("_disabled"):
qstate.pop(key)

for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
if mode_str == "quantize" and "metadata" in m_state:
m_state["metadata"]["quantizer_state"] = qstate
break

# Per-quantizer tensor dict loaded alongside metadata on reload.
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")

# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)

for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
try:
for _, module in model.named_modules():
if isinstance(module, QuantModule):
for attr_name, quantizer in module.named_children():
if (
attr_name.endswith("weight_quantizer")
and isinstance(quantizer, TensorQuantizer)
and quantizer.is_enabled
):
quantizer.disable()
orig_rotate = quantizer._rotate
if quantizer.rotate_is_enabled:
quantizer._rotate = disable_rotate(quantizer)
wqs_to_restore.append((quantizer, orig_rotate))

quantizer_state_dict = get_quantizer_state_dict(model)
for key in list(quantizer_state_dict):
if key.endswith("weight_quantizer"):
# Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
quantizer_state_dict.pop(key)
elif key in input_quantizers_folded_pqs:
# pre_quant_scale was folded into the weight; keep the buffer for strict load but
# save identity so activations are not scaled twice.
qstate_val = quantizer_state_dict[key]
if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
qstate_val["_pre_quant_scale"]
)
modelopt_state = mto.modelopt_state(model)
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
# ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
qstate = quantizer_state(model)
for key in list(qstate):
if key.endswith("weight_quantizer") and qstate[key].get("_disabled"):
qstate.pop(key)

for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
if mode_str == "quantize" and "metadata" in m_state:
m_state["metadata"]["quantizer_state"] = qstate
break

# Per-quantizer tensor dict loaded alongside metadata on reload.
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")

# Step 3: Save HF weights.
# Accelerate-offloaded models may ignore state_dict= in save_pretrained()
# and leak quantizer keys, so use manual save only in that case.
if has_offloaded_weights:
_save_clean_checkpoint(model, clean_sd, export_dir)
else:
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
finally:
for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
Loading
Loading