From 713ad1dd21dea6bfd8e030db9c7d1d9901951e15 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 27 Dec 2024 22:00:01 -0800 Subject: [PATCH] [ET-VK] Modify conv 2d pw op shader and dispatch settings to linearly dispatch work accounting for linearity texture to improve performance. This diff modifies the convolution 2D pointwise op shader and dispatch settings to dispatch work linearly, accounting for texture linearity, to improve performance. For the pointwise method, the global work group size is flattened to {x * y * z, 1, 1}, and the shader reconstructs its 3D position from gl_GlobalInvocationID.x. Differential Revision: [D67683411](https://our.internmc.facebook.com/intern/diff/D67683411/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl | 7 ++++++- backends/vulkan/runtime/graph/ops/impl/Convolution.cpp | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index b1950f970e4..9d1f6c3bd91 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -40,7 +40,12 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; * size is only 1x1, making it easier to re-use loaded texels from t_kernel. 
*/ void main() { - const u16vec3 gpos = u16vec3(gl_GlobalInvocationID); + const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE); + + const u16vec3 gpos = u16vec3( + gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z), + (gl_GlobalInvocationID.x / out_limits.z) % out_limits_y_scaled, + gl_GlobalInvocationID.x % out_limits.z); // Output position for TILE_SIZE = 2 // +--------+--------+ diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 6e9adf7d5a2..4f123cb8337 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -372,6 +372,10 @@ void add_conv2d_node( utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out); + if (method == Conv2dMethod::Pointwise) { + wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1}; + } + graph.execute_nodes().emplace_back(new DispatchNode( graph, shader,