Added HiFi optimized mean and where ops. #6483
Conversation
Adding mean and where ops optimized on HiFi
🔗 Helpful Links — 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6483
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
||
| int prepare_data( | ||
| const Tensor& in, | ||
| Tensor& out, |
hsharma35
left a comment
There was a problem hiding this comment.
LGTM overall. Added some minor comments for readability.
| constexpr auto name = "mean.out"; | ||
| constexpr int kNnlibMaxDim = 4; | ||
|
|
||
| bool optimized = 1; |
There was a problem hiding this comment.
Nit: Please use true or false instead of 0 or 1
| for (int i = 0; i < num_inp_dims; i++) { | ||
| inp_shape[i] = in.size(i); | ||
| } | ||
|
|
||
| for (int i = 0; i < num_out_dims; i++) { | ||
| out_shape[i] = out.size(i); | ||
| } |
There was a problem hiding this comment.
Nit: Can we make this a helper function that accepts a `const Tensor&` and a mutable `int*`?
| optional<ArrayRef<int64_t>> dim_list, | ||
| int* inp_shape, | ||
| int* out_shape, | ||
| int* p_axis, |
There was a problem hiding this comment.
Nit: What is p_axis used for? Can we please rename this to make the usage obvious (or add comment)?
| int scratch_size = xa_nn_reduce_getsize_nhwc( | ||
| -3, inp_shape, num_inp_dims, p_axis, num_axis_dims, 1); | ||
|
|
||
| void* __restrict__ p_scratch_in = (void* __restrict__)malloc(scratch_size); |
There was a problem hiding this comment.
Let's use temporary allocator instead of malloc please.
|
|
||
| int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim(), | ||
| out_dim = out.dim(); | ||
| bool optimized = 1; |
There was a problem hiding this comment.
Nit: use true or false instead of integers.
| for (int i = 0; i < 4; i++) { | ||
| con_shape[i] = out_shape[i]; | ||
| } | ||
| xa_nn_elm_where_broadcast_4D_f32xf32_f32( |
There was a problem hiding this comment.
ret = nnlib(...)
ET_KERNEL_CHECK(
ctx,
ret == 0,
ERROR_CODE,
out);
| con_shape); | ||
| free(p_scratch); | ||
| } else { | ||
| xa_nn_elm_where_broadcast_4D_f32xf32_f32( |
There was a problem hiding this comment.
ret = nnlib(...)
ET_KERNEL_CHECK(
ctx,
ret == 0,
ERROR_CODE,
out);
| con_shape); | ||
| } | ||
| } else { | ||
| xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel()); |
There was a problem hiding this comment.
ret = nnlib(...)
ET_KERNEL_CHECK(
ctx,
ret == 0,
ERROR_CODE,
out);
| "Unhandled dtype %s for where.self_out", | ||
| torch::executor::toString(cond_type)); | ||
|
|
||
| int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim(), |
There was a problem hiding this comment.
Nit: rename con_dim -> cond_dim
| if (optimized) { | ||
| const float* a_data = a.const_data_ptr<float>(); | ||
| const float* b_data = b.const_data_ptr<float>(); | ||
| float* out_data = out.mutable_data_ptr<float>(); | ||
| const unsigned char* con = cond.const_data_ptr<uint8_t>(); | ||
|
|
||
| if (broadcast == 1) { | ||
| int out_shape[kNnlibMaxDim]; | ||
| int inp1_shape[kNnlibMaxDim]; | ||
| int inp2_shape[kNnlibMaxDim]; | ||
| int con_shape[kNnlibMaxDim]; | ||
|
|
||
| for (int i = 0; i < kNnlibMaxDim; i++) { | ||
| con_shape[i] = 1; | ||
| out_shape[i] = 1; | ||
| inp1_shape[i] = 1; | ||
| inp2_shape[i] = 1; | ||
| } | ||
|
|
||
| int off_o = kNnlibMaxDim - out.dim(); | ||
| int off_a = kNnlibMaxDim - a.dim(); | ||
| int off_b = kNnlibMaxDim - b.dim(); | ||
| int off_c = kNnlibMaxDim - cond.dim(); | ||
|
|
||
| for (int i = 0; i < out.dim(); i++) | ||
| out_shape[i + off_o] = out.size(i); | ||
| for (int i = 0; i < a.dim(); i++) | ||
| inp1_shape[i + off_a] = a.size(i); | ||
| for (int i = 0; i < b.dim(); i++) | ||
| inp2_shape[i + off_b] = b.size(i); | ||
| for (int i = 0; i < cond.dim(); i++) | ||
| con_shape[i + off_c] = cond.size(i); | ||
|
|
||
| if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] || | ||
| con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) { | ||
| void* p_scratch = | ||
| malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]); | ||
| const unsigned char* p_brd_cond = (const unsigned char*)p_scratch; | ||
| xa_nn_broadcast_8_8( | ||
| (WORD8* __restrict__)p_brd_cond, | ||
| out_shape, | ||
| (const WORD8* __restrict__)con, | ||
| con_shape, | ||
| 4); | ||
|
|
||
| for (int i = 0; i < 4; i++) { | ||
| con_shape[i] = out_shape[i]; | ||
| } | ||
| xa_nn_elm_where_broadcast_4D_f32xf32_f32( | ||
| out_data, | ||
| out_shape, | ||
| a_data, | ||
| inp1_shape, | ||
| b_data, | ||
| inp2_shape, | ||
| p_brd_cond, | ||
| con_shape); | ||
| free(p_scratch); | ||
| } else { | ||
| xa_nn_elm_where_broadcast_4D_f32xf32_f32( | ||
| out_data, | ||
| out_shape, | ||
| a_data, | ||
| inp1_shape, | ||
| b_data, | ||
| inp2_shape, | ||
| con, | ||
| con_shape); | ||
| } | ||
| } else { | ||
| xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel()); | ||
| } | ||
| return out; | ||
| } |
There was a problem hiding this comment.
Can we move this to a separate inline function?
No description provided.