Conversation

@alanwaketan (Collaborator)

Summary:
The clipped value is currently computed in fp32 regardless of the input dtype. Let's make it follow the input dtype instead.

Test Plan:
python test/test_operations.py -v -k test_clip_grad_norm_
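
For illustration, a minimal sketch of the intended behavior (not the actual torch_xla patch; it mirrors the shape of torch's stock clip_grad_norm_, assumes a finite norm_type, and the 1e-6 epsilon is the usual stabilizer): the total norm is computed as before, but the clip coefficient follows the gradients' dtype instead of being promoted to fp32.

```python
import torch

def clip_grad_norm_(parameters, max_norm, norm_type=2.0):
    # Minimal sketch, not the actual torch_xla patch.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    dtype = parameters[0].grad.dtype  # follow the input dtype
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type)
                     for p in parameters]),
        norm_type)
    # Keep the clip coefficient in the gradients' own dtype
    # instead of promoting to fp32.
    clip_coef = (max_norm / (total_norm + 1e-6)).to(dtype)
    clip_coef = torch.clamp(clip_coef, max=1.0)
    for p in parameters:
        p.grad.detach().mul_(clip_coef)
    return total_norm
```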

@alanwaketan alanwaketan requested a review from JackCaoG June 5, 2024 23:19
@alanwaketan alanwaketan self-assigned this Jun 5, 2024
norm_type = float(norm_type)
if len(parameters) == 0:
    return torch.tensor(0.)
dtype = parameters[0].grad.dtype

Collaborator:

oh man.... I totally forgot we have this patch.


Collaborator (Author):

Not sure if it actually is needed though... haha
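
As a quick illustration, exercising the sketch from the description above (hypothetical; the real coverage lives in test_clip_grad_norm_) shows the dtype flowing through end to end:

```python
import torch

# Uses the clip_grad_norm_ sketch defined above (hypothetical).
# With a bf16 gradient, the clip coefficient is computed in bf16
# and the gradient keeps its dtype after the in-place scaling.
p = torch.nn.Parameter(torch.randn(8, dtype=torch.bfloat16))
p.grad = torch.randn(8, dtype=torch.bfloat16)
total_norm = clip_grad_norm_([p], max_norm=1.0)
assert p.grad.dtype == torch.bfloat16
```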

@alanwaketan (Collaborator, Author)

Thanks, Jack!

@alanwaketan alanwaketan merged commit f5c0062 into master Jun 6, 2024