-
Notifications
You must be signed in to change notification settings - Fork 353
fix: handle accelerate CPU-offloaded models in FakeQuant export #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3895d21
b85c4e0
8c189ac
46776ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,8 @@ | |
| # limitations under the License. | ||
| """Export HuggingFace model to vLLM fakequant checkpoint.""" | ||
|
|
||
| import logging | ||
| from collections.abc import Mapping | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
@@ -26,6 +28,8 @@ | |
| from modelopt.torch.quantization.utils import get_quantizer_state_dict | ||
| from modelopt.torch.utils import get_unwrapped_name | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| __all__ = ["export_hf_vllm_fq_checkpoint"] | ||
|
|
||
|
|
||
|
|
@@ -38,6 +42,135 @@ def disable_rotate(quantizer: TensorQuantizer): | |
| return False | ||
|
|
||
|
|
||
| def _materialize_offloaded_weights( | ||
| model: nn.Module, | ||
| state_dict: dict[str, torch.Tensor], | ||
| meta_keys: list[str], | ||
| ) -> None: | ||
| """Replace meta tensors in state_dict with actual data from accelerate offload hooks. | ||
|
|
||
| When a model is loaded with ``device_map="auto"`` and some layers are offloaded | ||
| to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those | ||
| layers. This function walks the model's accelerate hooks to retrieve the actual | ||
| weight data and updates state_dict in-place. | ||
| """ | ||
| hook_entries: list[tuple[str, str, Mapping[str, torch.Tensor]]] = [] | ||
|
|
||
| def _weights_map_from_hook(hook_obj): | ||
| """Best-effort extraction of an accelerate weights_map from a hook object.""" | ||
| if hasattr(hook_obj, "weights_map") and hook_obj.weights_map is not None: | ||
| return hook_obj.weights_map | ||
| if hasattr(hook_obj, "hooks"): | ||
| for h in hook_obj.hooks: | ||
| if hasattr(h, "weights_map") and h.weights_map is not None: | ||
| return h.weights_map | ||
| return None | ||
|
|
||
| try: | ||
| # Reuse accelerate plugin hook resolution instead of duplicating traversal logic. | ||
| from modelopt.torch.quantization.plugins.accelerate import _get_cpu_offload_hook | ||
| except ImportError: | ||
| _get_cpu_offload_hook = None | ||
|
|
||
| for name, module in model.named_modules(): | ||
| hook = getattr(module, "_hf_hook", None) | ||
| if hook is None or _get_cpu_offload_hook is None: | ||
| continue | ||
|
|
||
| align_hook = None | ||
| if _get_cpu_offload_hook is not None: | ||
| try: | ||
| align_hook = _get_cpu_offload_hook(hook) | ||
| except AssertionError: | ||
| # Some accelerate hook variants do not expose a plain "weight" key. | ||
| # Fall back to generic weights_map extraction for export-time readout. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: the guard `if _get_cpu_offload_hook is not None:` is dead code — the enclosing loop already `continue`s whenever `_get_cpu_offload_hook is None`, so it is guaranteed to be non-None at this point. Consider removing the redundant check. |
||
| align_hook = None | ||
|
|
||
| wmap = align_hook.weights_map if align_hook is not None else _weights_map_from_hook(hook) | ||
| if wmap is None: | ||
| continue | ||
|
|
||
| if hasattr(wmap, "dataset"): | ||
| weight_prefix = wmap.prefix | ||
| actual_sd = wmap.dataset.state_dict | ||
| else: | ||
| weight_prefix = "" | ||
| actual_sd = wmap | ||
|
|
||
| module_prefix = f"{name}." if name else "" | ||
| hook_entries.append((module_prefix, weight_prefix, actual_sd)) | ||
|
|
||
| # Match most-specific module prefixes first to avoid ambiguous parent-prefix hits. | ||
| hook_entries.sort(key=lambda x: len(x[0]), reverse=True) | ||
|
|
||
| materialized = 0 | ||
| for key in meta_keys: | ||
| for module_prefix, weight_prefix, actual_sd in hook_entries: | ||
| if not key.startswith(module_prefix): | ||
| continue | ||
| local_key = key[len(module_prefix) :] | ||
| lookup_key = weight_prefix + local_key | ||
| if lookup_key in actual_sd: | ||
| state_dict[key] = actual_sd[lookup_key].detach().clone() | ||
| materialized += 1 | ||
| break | ||
| else: | ||
| logger.warning("Could not materialize meta tensor for key: %s", key) | ||
|
|
||
| logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys)) | ||
|
|
||
|
|
||
| def _save_clean_checkpoint( | ||
| model: nn.Module, | ||
| clean_sd: dict[str, torch.Tensor], | ||
| export_dir: Path, | ||
| ) -> None: | ||
| """Save clean weights + config directly, bypassing model.save_pretrained(). | ||
|
|
||
| For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)`` | ||
| ignores the provided state_dict and saves from internal state, leaking | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Behavioral change for ALL models: for the vLLM FakeQuant use case this is probably fine (vLLM doesn't appear to need the extra auxiliary files `save_pretrained` writes), but bypassing `save_pretrained` unconditionally would also affect models that are not offloaded — consider restricting the manual save path to the offloaded case.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _save_clean_checkpoint is now used only when offloaded/meta tensors are detected. |
||
| quantizer keys. This function saves ``clean_sd`` directly via safetensors | ||
| API, guaranteeing only the intended keys are written. | ||
| """ | ||
| import json | ||
|
|
||
| from huggingface_hub import split_torch_state_dict_into_shards | ||
| from safetensors.torch import save_file | ||
|
|
||
| export_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB") | ||
| for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): | ||
| # Keep peak memory bounded: move and clone one shard at a time. | ||
| # Cloning also breaks shared storage, which safetensors rejects. | ||
| shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys} | ||
| save_file(shard, str(export_dir / shard_file)) | ||
| logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding |
||
|
|
||
| if state_dict_split.is_sharded: | ||
| index = { | ||
| "metadata": state_dict_split.metadata, | ||
| "weight_map": state_dict_split.tensor_to_filename, | ||
| } | ||
| (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) | ||
|
|
||
| if hasattr(model, "config"): | ||
| config = model.config.to_dict() | ||
| config_path = export_dir / "config.json" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: `config = model.config.to_dict()` may retain an `auto_map` entry that points at custom code files which are not copied into the export directory. Suggest adding:
config.pop("auto_map", None)  # Custom code files are not in the export dir |
||
| config_path.write_text(json.dumps(config, indent=2) + "\n") | ||
| logger.info("Saved config.json") | ||
|
|
||
| generation_config = getattr(model, "generation_config", None) | ||
| if generation_config is not None: | ||
| generation_config.save_pretrained(export_dir) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing tokenizer save. Also, some models have additional files saved by `save_pretrained` (e.g. chat templates or custom modeling code) that this manual save path will not write — consider handling those as well. |
||
| logger.info( | ||
| "Checkpoint saved: %d weights in %d shard(s)", | ||
| len(clean_sd), | ||
| len(state_dict_split.filename_to_tensors), | ||
| ) | ||
|
|
||
|
|
||
| def export_hf_vllm_fq_checkpoint( | ||
| model: nn.Module, | ||
| export_dir: Path | str, | ||
|
|
@@ -62,6 +195,31 @@ def export_hf_vllm_fq_checkpoint( | |
| # parameters are never modified. Apply each weight quantizer's fake-quant | ||
| # to the corresponding weight tensor in the copy. | ||
| state_dict = model.state_dict() | ||
|
|
||
| # Handle accelerate-offloaded models: state_dict() returns meta tensors | ||
| # for CPU/disk-offloaded layers. Materialize them from the offload hooks. | ||
| meta_keys = [k for k, v in state_dict.items() if v.is_meta] | ||
| has_offloaded_weights = bool(meta_keys) | ||
| if meta_keys: | ||
| logger.info( | ||
| "Found %d meta tensors in state_dict (accelerate offloading). " | ||
| "Materializing from offload hooks...", | ||
| len(meta_keys), | ||
| ) | ||
| _materialize_offloaded_weights(model, state_dict, meta_keys) | ||
| unresolved_meta_keys = [ | ||
| k | ||
| for k, v in state_dict.items() | ||
| if v.is_meta and "quantizer" not in k and "quant" not in k | ||
| ] | ||
| if unresolved_meta_keys: | ||
| shown = ", ".join(unresolved_meta_keys[:10]) | ||
| suffix = " ..." if len(unresolved_meta_keys) > 10 else "" | ||
| raise RuntimeError( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The unresolved meta key check filters out keys containing |
||
| "Failed to materialize offloaded tensors before fake-quant folding / " | ||
| f"_save_clean_checkpoint: {shown}{suffix}" | ||
| ) | ||
|
|
||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| fakequant_weights = set() | ||
| input_quantizers_folded_pqs = ( | ||
| set() | ||
|
|
@@ -86,6 +244,26 @@ def export_hf_vllm_fq_checkpoint( | |
| ) | ||
| if sd_key in state_dict: | ||
| w = state_dict[sd_key] | ||
| # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. | ||
| # Offloaded weights materialized to CPU need a GPU hop. | ||
| if not w.is_cuda: | ||
| # Find a CUDA device from quantizer/module tensors. | ||
| cuda_dev = None | ||
| for t in list(quantizer.parameters()) + list(quantizer.buffers()): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential silent failure: if no CUDA device is found (e.g., all quantizer buffers/params and module params happen to be on CPU/meta), `cuda_dev` stays `None`. Raise an explicit error instead of failing obscurely:
raise RuntimeError(
f"Cannot find CUDA device for quantizer kernel on offloaded weight '{sd_key}'. "
"Ensure at least one quantizer buffer or module parameter is on CUDA."
)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added explicit error |
||
| for t in list(module.parameters()) + list(module.buffers()): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is None: | ||
| raise RuntimeError( | ||
| "Cannot find CUDA device for quantizer kernel on offloaded weight " | ||
| f"'{sd_key}'. Ensure at least one quantizer/module tensor is on CUDA." | ||
| ) | ||
| w = w.to(cuda_dev) | ||
| w_quant = quantizer(w.float()).to(w.dtype).cpu() | ||
| # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) | ||
| # Only valid when input_quantizer does NOT fake-quant activations. If it does | ||
|
|
@@ -117,53 +295,59 @@ def export_hf_vllm_fq_checkpoint( | |
| # Rotation is also cleared: the weight was already folded with rotation applied, | ||
| # so if fold_weight is called on reload it must not re-rotate the exported weight. | ||
| wqs_to_restore = [] | ||
| for _, module in model.named_modules(): | ||
| if isinstance(module, QuantModule): | ||
| for attr_name, quantizer in module.named_children(): | ||
| if ( | ||
| attr_name.endswith("weight_quantizer") | ||
| and isinstance(quantizer, TensorQuantizer) | ||
| and quantizer.is_enabled | ||
| ): | ||
| quantizer.disable() | ||
| orig_rotate = quantizer._rotate | ||
| if quantizer.rotate_is_enabled: | ||
| quantizer._rotate = disable_rotate(quantizer) | ||
| wqs_to_restore.append((quantizer, orig_rotate)) | ||
|
|
||
| quantizer_state_dict = get_quantizer_state_dict(model) | ||
| for key in list(quantizer_state_dict): | ||
| if key.endswith("weight_quantizer"): | ||
| # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. | ||
| quantizer_state_dict.pop(key) | ||
| elif key in input_quantizers_folded_pqs: | ||
| # pre_quant_scale was folded into the weight; keep the buffer for strict load but | ||
| # save identity so activations are not scaled twice. | ||
| qstate_val = quantizer_state_dict[key] | ||
| if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: | ||
| quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( | ||
| qstate_val["_pre_quant_scale"] | ||
| ) | ||
| modelopt_state = mto.modelopt_state(model) | ||
| # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild | ||
| # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). | ||
| qstate = quantizer_state(model) | ||
| for key in list(qstate): | ||
| if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): | ||
| qstate.pop(key) | ||
|
|
||
| for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): | ||
| if mode_str == "quantize" and "metadata" in m_state: | ||
| m_state["metadata"]["quantizer_state"] = qstate | ||
| break | ||
|
|
||
| # Per-quantizer tensor dict loaded alongside metadata on reload. | ||
| modelopt_state["modelopt_state_weights"] = quantizer_state_dict | ||
| torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") | ||
|
|
||
| # Step 3: Save HF weights using the pre-built folded state dict. | ||
| model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) | ||
|
|
||
| for wq, orig_rotate in wqs_to_restore: | ||
| wq.enable() | ||
| wq._rotate = orig_rotate | ||
| try: | ||
| for _, module in model.named_modules(): | ||
| if isinstance(module, QuantModule): | ||
| for attr_name, quantizer in module.named_children(): | ||
| if ( | ||
| attr_name.endswith("weight_quantizer") | ||
| and isinstance(quantizer, TensorQuantizer) | ||
| and quantizer.is_enabled | ||
| ): | ||
| quantizer.disable() | ||
| orig_rotate = quantizer._rotate | ||
| if quantizer.rotate_is_enabled: | ||
| quantizer._rotate = disable_rotate(quantizer) | ||
| wqs_to_restore.append((quantizer, orig_rotate)) | ||
|
|
||
| quantizer_state_dict = get_quantizer_state_dict(model) | ||
| for key in list(quantizer_state_dict): | ||
| if key.endswith("weight_quantizer"): | ||
| # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. | ||
| quantizer_state_dict.pop(key) | ||
| elif key in input_quantizers_folded_pqs: | ||
| # pre_quant_scale was folded into the weight; keep the buffer for strict load but | ||
| # save identity so activations are not scaled twice. | ||
| qstate_val = quantizer_state_dict[key] | ||
| if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: | ||
| quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( | ||
| qstate_val["_pre_quant_scale"] | ||
| ) | ||
| modelopt_state = mto.modelopt_state(model) | ||
| # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild | ||
| # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). | ||
| qstate = quantizer_state(model) | ||
| for key in list(qstate): | ||
| if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): | ||
| qstate.pop(key) | ||
|
|
||
| for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): | ||
| if mode_str == "quantize" and "metadata" in m_state: | ||
| m_state["metadata"]["quantizer_state"] = qstate | ||
| break | ||
|
|
||
| # Per-quantizer tensor dict loaded alongside metadata on reload. | ||
| modelopt_state["modelopt_state_weights"] = quantizer_state_dict | ||
| torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") | ||
|
|
||
| # Step 3: Save HF weights. | ||
| # Accelerate-offloaded models may ignore state_dict= in save_pretrained() | ||
| # and leak quantizer keys, so use manual save only in that case. | ||
| if has_offloaded_weights: | ||
| _save_clean_checkpoint(model, clean_sd, export_dir) | ||
| else: | ||
| model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) | ||
| finally: | ||
| for wq, orig_rotate in wqs_to_restore: | ||
| wq.enable() | ||
| wq._rotate = orig_rotate | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: when `_get_cpu_offload_hook` is `None` (ImportError), the function will skip all modules with `_hf_hook`, meaning no weights get materialized. This will then trigger the RuntimeError downstream for any non-quantizer meta keys, which is fine. But a more informative early error or warning when the accelerate plugin import fails would help debugging.