fix(jax): fix Hessian NaN for DPA-3 #4668
This fixes the NaN when calculating the Hessian in the JAX backend.
Reproduce NaN:
```py
from deepmd.pt.utils.serialization import serialize_from_file
from deepmd.jax.model.ener_model import EnergyModel
from deepmd.jax.env import jax, jnp
import numpy as np

# Raise an error as soon as any computation produces a NaN.
jax.config.update("jax_debug_nans", True)

# Load the serialized model and rebuild it as a JAX energy model.
data = serialize_from_file("frozen_model.pth")
model = EnergyModel.deserialize(data["model"])
model.enable_hessian()
model_call = jax.jit(model.call)

# Coordinates: nframes x natoms x 3
coord = np.array([[[0, 0, 0], [1, 1, 1]]], dtype=np.float64)
# Atom types: nframes x natoms
atype = np.array([[0, 1]], dtype=int)

# The Hessian output is where the NaN appears without this fix.
result = model_call(jnp.array(coord), jnp.array(atype), None)
print(result["energy_derv_r_derv_r"])
```
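The same kind of failure can be seen without a trained model. The sketch below is an assumption about the mechanism, not code from this PR: it differentiates a hand-written Euclidean norm (the name `naive_norm` is hypothetical) at the zero vector, where the value is well defined but the chain rule multiplies an infinite derivative of `sqrt` by a zero inner derivative.

```python
import jax
import jax.numpy as jnp


def naive_norm(diff):
    # Plain Euclidean norm: square root of the sum of squares.
    return jnp.sqrt(jnp.sum(diff * diff))


zero_diff = jnp.zeros(3)

# The value itself is finite (0.0).
value = naive_norm(zero_diff)

# The gradient is not: d sqrt(s)/ds is inf at s = 0 and ds/dx is 0,
# so reverse-mode differentiation evaluates inf * 0.
grad = jax.grad(naive_norm)(zero_diff)

print(value, grad)
```

With `jax_debug_nans` enabled, as in the reproduction above, a NaN like this one raises an error instead of being returned silently.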
Signed-off-by: Jinzhe Zeng <[email protected]>
Pull Request Overview
This PR fixes the NaN issue encountered when calculating the Hessian in the JAX backend by replacing direct calls to vector_norm with the safe_for_vector_norm function.
- Replaces a manual safeguard in env_mat.py with safe_for_vector_norm.
- Updates repflows.py to use safe_for_vector_norm for computing normalized differences.
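One common way to make a norm safe to differentiate at zero is the "double where" pattern. The sketch below illustrates that pattern under stated assumptions: `toy_safe_norm` is a hypothetical stand-in written for this example and is not claimed to match the actual `safe_for_vector_norm` implementation in deepmd.

```python
import jax
import jax.numpy as jnp


def toy_safe_norm(x):
    sq = jnp.sum(x * x)
    nonzero = sq > 0
    # Inner where: feed sqrt a harmless value (1.0) when the input is the
    # zero vector, so the derivative of sqrt is never evaluated at 0.
    safe_sq = jnp.where(nonzero, sq, 1.0)
    # Outer where: restore the correct value (0.0) for the zero vector.
    return jnp.where(nonzero, jnp.sqrt(safe_sq), 0.0)


zero = jnp.zeros(3)
grad_at_zero = jax.grad(toy_safe_norm)(zero)
hess_at_zero = jax.hessian(toy_safe_norm)(zero)

# For non-zero input the result is the ordinary Euclidean norm.
norm_345 = toy_safe_norm(jnp.array([3.0, 4.0, 0.0]))

print(grad_at_zero, hess_at_zero, norm_345)
```

The derivatives returned at the zero vector are a convention (the norm is not differentiable there), but they are finite, which is what keeps NaN out of the first and second derivatives.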
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| deepmd/dpmodel/utils/env_mat.py | Replaces manual vector norm safeguard with safe_for_vector_norm. |
| deepmd/dpmodel/descriptor/repflows.py | Updates normalized difference calculation to use safe_for_vector_norm. |
safe_for_vector_norm for env mat and dpa3
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/dpmodel/descriptor/repflows.py (2 hunks)
- deepmd/dpmodel/utils/env_mat.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (true)
🔇 Additional comments (4)
deepmd/dpmodel/utils/env_mat.py (2)

16-18: Good addition of the new import for safe gradients.
The addition of `safe_for_vector_norm` from the `deepmd.dpmodel.utils.safe_gradient` module will help address NaN values in gradient calculations when vector norms are computed at zero vectors.

64-64: Excellent fix for preventing NaN gradients in JAX backend.
This change directly addresses the PR's objective by replacing the standard vector norm calculation with a safer version that can handle zero-length vectors. The previous implementation was causing NaN values in the gradient calculation when the input vector was zero, as indicated by the comment on line 63.
The implementation now uses `safe_for_vector_norm` instead of directly calling `xp.linalg.vector_norm`, which should prevent NaN values in the Hessian calculations while preserving the same numerical behavior for non-zero vectors.

deepmd/dpmodel/descriptor/repflows.py (2)

30-32: Good addition of the new import for safe gradients.
The addition of `safe_for_vector_norm` from the `deepmd.dpmodel.utils.safe_gradient` module will help address NaN values in gradient calculations when vector norms are computed at zero vectors.

421-421: Excellent fix for preventing NaN gradients in JAX backend.
This change replaces the standard vector norm calculation with a safer version that can handle zero-length vectors. Previously, the code added a small constant (1e-6) to prevent division by zero, but that wasn't sufficient to prevent NaN gradients in JAX. Using `safe_for_vector_norm` provides a more robust solution.
Note that the small constant (1e-6) is still maintained after the call to `safe_for_vector_norm`. This is good for ensuring numerical stability in the subsequent operations, even with the safer vector norm implementation.
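The point about the 1e-6 constant can be illustrated with a small sketch. Everything here is an assumption made for illustration: `normalize`, `naive_norm` and `toy_safe_norm` are hypothetical names, and the code is not taken from repflows.py. It shows that an epsilon added after the norm keeps the division finite but does not stop the backward pass from going through the derivative of `sqrt` at zero.

```python
import jax
import jax.numpy as jnp

EPS = 1e-6


def naive_norm(x):
    return jnp.sqrt(jnp.sum(x * x))


def toy_safe_norm(x):
    # Hypothetical "double where" norm; not the deepmd implementation.
    sq = jnp.sum(x * x)
    nonzero = sq > 0
    return jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, sq, 1.0)), 0.0)


def normalize(x, norm_fn):
    # EPS prevents division by zero, but it is added after the norm,
    # so it does not change how the norm itself is differentiated.
    return x / (norm_fn(x) + EPS)


zero = jnp.zeros(3)
jac_naive = jax.jacobian(lambda x: normalize(x, naive_norm))(zero)
jac_safe = jax.jacobian(lambda x: normalize(x, toy_safe_norm))(zero)

print(jac_naive)  # contains NaN
print(jac_safe)   # finite
```

In this sketch the forward value is identical for both variants; only the derivative differs, which is consistent with a NaN that shows up in derivative outputs such as the Hessian.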
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

Coverage Diff

| Metric | devel | #4668 | +/- |
|---|---|---|---|
| Coverage | 84.81% | 84.81% | |
| Files | 692 | 692 | |
| Lines | 66360 | 66361 | +1 |
| Branches | 3539 | 3540 | +1 |
| Hits | 56282 | 56283 | +1 |
| Misses | 8937 | 8936 | -1 |
| Partials | 1141 | 1142 | +1 |
Signed-off-by: Jinzhe Zeng <[email protected]>