-
Notifications
You must be signed in to change notification settings - Fork 561
Std mean #3004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
JackCaoG
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly LGTM, minor nits. Please fix the format too.
torch_xla/csrc/ops/std_mean.cpp
Outdated
| xla::XlaOp LowerStd(xla::XlaOp input, | ||
| const std::vector<xla::int64>& dimensions, | ||
| bool keep_reduced_dimensions, | ||
| xla::int64 correction) { | ||
| return BuildStdDeviation(input, dimensions, keep_reduced_dimensions, correction); | ||
| } | ||
|
|
||
| xla::XlaOp LowerMean(xla::XlaOp input, | ||
| const std::vector<xla::int64>& dimensions, | ||
| bool keep_reduced_dimensions) { | ||
| return BuildMean(input, dimensions, keep_reduced_dimensions); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't really need these two helper function if they are just forwarding argument to BuildStdDeviation and BuildMean. You only replace the use of these two function with Buildxxx directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
torch_xla/csrc/ops/std_mean.cpp
Outdated
| bool keep_reduced_dimensions, | ||
| xla::int64 correction) { | ||
| auto lower_for_shape_fn_std_mean = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { | ||
| auto std = LowerStd(operands[0], dimensions, keep_reduced_dimensions, correction); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's try not to use auto when we can 😄 , I think in general we only use it when the type name is super long and complicated. In this case we can just use xla::XlaOp.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
torch_xla/csrc/ops/std_mean.cpp
Outdated
| xla::int64 correction) { | ||
| auto lower_for_shape_fn_std_mean = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { | ||
| auto std = LowerStd(operands[0], dimensions, keep_reduced_dimensions, correction); | ||
| auto mean = LowerMean(operands[0], dimensions, keep_reduced_dimensions); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
JackCaoG
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @miladm, minor nits. You can merge it once you fix the comment and verify all test passed.
| ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); | ||
| ExpectCounterChanged("xla::std_mean", cpp_test::GetIgnoredCounters()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT, move this two check our of the for loop. We only need to check once at the end of the test, instead of for every execution.
The PR draft for std_mean.
The PR is to be tested.
Issue link