-
Notifications
You must be signed in to change notification settings - Fork 6.7k
ONNX export: Square and ReduceSum operators #12653
Conversation
|
@mxnet-label-bot[pr-awaiting-review] |
f228e06 to
3d9372d
Compare
anirudhacharya
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should wait on #12633 and please add tests for these operators.
|
Rebased on top of #12633. @anirudhacharya @Roshrini @zhreshold This PR is ready for review again. |
| mx_axis = attrs.get("axis", None) | ||
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = 1 if ("keepdims" in attrs) and \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you need the first condition ("keepdims" in attrs) here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's an optional parameter. so added this check just in case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, my point was since you are doing attrs.get("keepdims") the first condition is a redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, considering changing it to
keepdims = attrs.get("keepdims", 0)
keepdims = 1 if keepdims in ["True", "1"] else 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking of this -
keepdims = 1 if attrs.get("keepdims", 0) in ["True", "1"] else 0
but this should be ok i guess.
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = 1 if ("keepdims" in attrs) and \ | ||
| attrs.get("keepdims") in ["True", "1"] else 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't attrs.get("keepdims") be cast to str before making this comparison? I am curious how it passed the unit test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the value has always been "True" or "1" or "0" when it reaches this function. same is the case with similar functions such as Reduce*, Arg*. I will figure out where this change in type occurs and get back to you.
| keepdims=keepdims, | ||
| name=name | ||
| ) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you remove new line here and in line 2150. Or maybe you could have a single return statement after the if else block.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will change it to a common return statement
| "Pow", | ||
| [input_node_a, power2_name], | ||
| [name], | ||
| name=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why name=None here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
was a copy from Pow operator. No reason to have it as None. Will add a name. and change it in Pow as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like name wasn't added to Pow because of the error "pow() got an unexpected keyword argument 'name'". Submitting the fix for this in this PR itself (separate commit).
|
LGTM |
|
@anirudhacharya @zhreshold ping for review |
| new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast']) | ||
| return 'broadcast_power', new_attrs, inputs | ||
| return 'pow', new_attrs, inputs | ||
| mxnet_op = symbol.pow(inputs[0], inputs[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this change? you probably made this change to accommodate op_set version 7 of ONNX. will we support op_set versions older than 7?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverting back to the older commit. Will figure out supporting different op_set versions and work on the Power operator separately.
|
@zhreshold @Roshrini @anirudhacharya review comments addressed. is this PR good to go? |
anirudhacharya
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
| inputs = node["inputs"] | ||
|
|
||
| input_node_a_id = kwargs["index_lookup"][inputs[0][0]] | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: in general feel that there are a lot of unnecessary blank lines in this method. also you can remove line 2096 have np.array([2]) directly. but this change is not a blocker
|
|
||
| initializer = kwargs["initializer"] | ||
| power2 = [2] | ||
| np_arr = np.array(power2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the point of setting power2 as a temporary variable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Re-submitting after editing this.
| axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None | ||
|
|
||
| keepdims = attrs.get("keepdims", 0) | ||
| keepdims = 1 if keepdims in ["True", "1"] else 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should reuse the function in #12646
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Re-submitting with this change.
Description
Enabling export of Square operator and Sum operator
v2: Added reduce_sum tests based on @anirudhacharya's review
v3: Addressed @Roshrini's and @zhreshold's review comments
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments