From c4bb212e0f6171926d126fcab456e4b0cfdf6bb8 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:28:28 -0800 Subject: [PATCH] [ET-VK] Rearranging code in permute op shader to reduce heavy math ops and improve performance. The diff rearranges Permute op shader code in the ExecuTorch Vulkan backend to reduce heavy math operations and improve performance. The change also includes adding an extension to support explicit arithmetic types and changing the data type of the position variable to u16vec3. Differential Revision: [D66174765](https://our.internmc.facebook.com/intern/diff/D66174765/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/permute.glsl | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 8414d811fc8..5378099d03f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -36,8 +36,10 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const u16vec3 pos = u16vec3(gl_GlobalInvocationID); if (any(greaterThanEqual(pos, out_limits))) { return; @@ -46,28 +48,34 @@ void main() { const int out_channel_4up = int(ch_info.x); const int in_channel_4up = int(ch_info.y); const int out_batch = int(sizes[3]); - const int max_dst_index = out_batch * out_channel_4up; VEC4_T outval = VEC4_T(0.0); + ivec4 v = ivec4(0); // holds b,c,h,w + + v[out_ndims[2]] = pos.y; + v[out_ndims[3]] = pos.x; + + const int dst_index = pos.z << 2; + int dst_out_index = dst_index / out_channel_4up; + int dst_out_lane = dst_index % out_channel_4up; - for (int j = 0; j < 4; ++j) { - int dst_index = 
pos.z * 4 + j; - if (dst_index >= max_dst_index) { + for (int j = 0; j < 4; ++j, ++dst_out_lane) { + if (dst_out_index >= out_batch) { // out of range break; } - ivec4 v = ivec4(0); // holds b,c,h,w - v[out_ndims[0]] = dst_index / out_channel_4up; - v[out_ndims[1]] = dst_index % out_channel_4up; - v[out_ndims[2]] = pos.y; - v[out_ndims[3]] = pos.x; + if (dst_out_lane == out_channel_4up) { + dst_out_lane = 0; + dst_out_index++; + } + + v[out_ndims[0]] = dst_out_index; + v[out_ndims[1]] = dst_out_lane; int src_index = v[0] * in_channel_4up + v[1]; - int w = v[3]; - int h = v[2]; - VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0)); - outval[j] = inval[src_index % 4]; + VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0)); + outval[j] = inval[src_index & 0x3]; } imageStore(image_out, pos, outval);