Skip to content

Conversation

@njzjz
Copy link
Member

@njzjz njzjz commented Mar 22, 2025

This fixes the NaN when calculating the Hessian in the JAX backend. Reproduce NaN:

from deepmd.pt.utils.serialization import serialize_from_file
from deepmd.jax.model.ener_model import EnergyModel
from deepmd.jax.env import jax, jnp
import numpy as np
jax.config.update("jax_debug_nans", True)

model = serialize_from_file('frozen_model.pth')
model = EnergyModel.deserialize(model["model"])
model.enable_hessian()
model_call = jax.jit(model.call)
# nframes x natoms x 3
coord = np.array([[[0,0,0],[1,1,1]]], dtype=np.float64)
# nframes x natoms
atype = np.array([[0,1]], dtype=int)
print(model_call(jnp.array(coord), jnp.array(atype), None)['energy_derv_r_derv_r'])

Summary by CodeRabbit

  • Bug Fixes
    • Improved numerical stability by enhancing the vector normalization process to better handle edge cases and ensure robust computations.

This fixes the NaN when calculating the Hessian in the JAX backend.
Reproduce NaN:
```py
from deepmd.pt.utils.serialization import serialize_from_file
from deepmd.jax.model.ener_model import EnergyModel
from deepmd.jax.env import jax, jnp
import numpy as np
jax.config.update("jax_debug_nans", True)

model = serialize_from_file('frozen_model.pth')
model = EnergyModel.deserialize(model["model"])
model.enable_hessian()
model_call = jax.jit(model.call)
# nframes x natoms x 3
coord = np.array([[[0,0,0],[1,1,1]]], dtype=np.float64)
# nframes x natoms
atype = np.array([[0,1]], dtype=int)
print(model_call(jnp.array(coord), jnp.array(atype), None)['energy_derv_r_derv_r'])
```

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes the NaN issue encountered when calculating the Hessian in the JAX backend by replacing direct calls to vector_norm with the safe_for_vector_norm function.

  • Replaces a manual safeguard in env_mat.py with safe_for_vector_norm.
  • Updates repflows.py to use safe_for_vector_norm for computing normalized differences.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
deepmd/dpmodel/utils/env_mat.py Replaces manual vector norm safeguard with safe_for_vector_norm.
deepmd/dpmodel/descriptor/repflows.py Updates normalized difference calculation to use safe_for_vector_norm.

@njzjz njzjz changed the title fix(jax): use safe_for_vector_norm for env mat and dpa3 fix(jax): fix Hessian NaN for DPA-3 Mar 22, 2025
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 22, 2025

📝 Walkthrough

Walkthrough

This pull request replaces direct calls to xp.linalg.vector_norm with a new, safer alternative—safe_for_vector_norm—in three components. The DescrptBlockRepflows and DescrptBlockRepformers classes, as well as the _make_env_mat function, have been updated to utilize the imported utility, which is designed to better handle edge cases when dealing with small or near-zero values. Additionally, the safe_for_vector_norm function is added to the public API in the safe gradient utility module.

Changes

File(s) Change Summary
deepmd/dpmodel/descriptor/repflows.py, deepmd/dpmodel/descriptor/repformers.py Replaced direct calls to xp.linalg.vector_norm with an imported safe_for_vector_norm to improve robustness in vector normalization computations.
deepmd/dpmodel/utils/env_mat.py Modified the calculation of length to use safe_for_vector_norm instead of a conditional operation with xp.linalg.vector_norm.
deepmd/dpmodel/utils/safe_gradient Added the new function safe_for_vector_norm(diff, axis=-1, keepdims=True) to encapsulate safer handling of vector norm calculations, addressing edge cases.

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd2b832 and 62c3f2b.

📒 Files selected for processing (1)
  • deepmd/dpmodel/descriptor/repformers.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/dpmodel/descriptor/repformers.py (2)

31-33: Good addition of the safe_for_vector_norm utility

Adding this import from the safe_gradient module will help prevent potential NaN values in numerical calculations, which is critical for molecular dynamics simulations.


420-420: Excellent fix for numerical stability

Replacing the direct vector norm calculation with safe_for_vector_norm is a critical fix that addresses the PR's main objective - preventing NaN values in Hessian calculations. This change properly handles edge cases when dealing with vectors of very small magnitude during normalization.

This modification will significantly improve the robustness of the JAX backend for DPA-3 (Deep Potential A3) by ensuring numerical stability in gradient calculations.

✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai plan to trigger planning for file edits and PR creation.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c9bfa79 and bd2b832.

📒 Files selected for processing (2)
  • deepmd/dpmodel/descriptor/repflows.py (2 hunks)
  • deepmd/dpmodel/utils/env_mat.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (4)
deepmd/dpmodel/utils/env_mat.py (2)

16-18: Good addition of the new import for safe gradients.

The addition of safe_for_vector_norm from the deepmd.dpmodel.utils.safe_gradient module will help address NaN values in gradient calculations when vector norms are computed at zero vectors.


64-64: Excellent fix for preventing NaN gradients in JAX backend.

This change directly addresses the PR's objective by replacing the standard vector norm calculation with a safer version that can handle zero-length vectors. The previous implementation was causing NaN values in the gradient calculation when the input vector was zero, as indicated by the comment on line 63.

The implementation now uses safe_for_vector_norm instead of directly calling xp.linalg.vector_norm, which should prevent NaN values in the Hessian calculations while preserving the same numerical behavior for non-zero vectors.

deepmd/dpmodel/descriptor/repflows.py (2)

30-32: Good addition of the new import for safe gradients.

The addition of safe_for_vector_norm from the deepmd.dpmodel.utils.safe_gradient module will help address NaN values in gradient calculations when vector norms are computed at zero vectors.


421-421: Excellent fix for preventing NaN gradients in JAX backend.

This change replaces the standard vector norm calculation with a safer version that can handle zero-length vectors. Previously, the code added a small constant (1e-6) to prevent division by zero, but that wasn't sufficient to prevent NaN gradients in JAX. Using safe_for_vector_norm provides a more robust solution.

Note that the small constant (1e-6) is still maintained after the call to safe_for_vector_norm. This is good for ensuring numerical stability in the subsequent operations, even with the safer vector norm implementation.

@codecov
Copy link

codecov bot commented Mar 22, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.81%. Comparing base (c9bfa79) to head (62c3f2b).
Report is 81 commits behind head on devel.

Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4668   +/-   ##
=======================================
  Coverage   84.81%   84.81%           
=======================================
  Files         692      692           
  Lines       66360    66361    +1     
  Branches     3539     3540    +1     
=======================================
+ Hits        56282    56283    +1     
+ Misses       8937     8936    -1     
- Partials     1141     1142    +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz requested review from iProzd and removed request for iProzd March 25, 2025 12:37
@njzjz njzjz enabled auto-merge March 25, 2025 13:24
@njzjz njzjz added this pull request to the merge queue Mar 25, 2025
Merged via the queue into deepmodeling:devel with commit 089995b Mar 25, 2025
60 checks passed
@njzjz njzjz deleted the fix-jax-dpa3-nan branch March 25, 2025 16:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants