 #include <ATen/native/vulkan/api/api.h>

-#include <ATen/native/vulkan/impl/Arithmetic.h>
-#include <ATen/native/vulkan/impl/Packing.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>

 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

 using namespace at::native::vulkan;

-//
-// Utilities
-//
-
 #define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \
   vTensor(                                           \
       api::context(),                                \
@@ -43,23 +38,159 @@ using namespace at::native::vulkan;
       api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \
       allocate_memory);

+//
+// Simplified versions of ATen Vulkan legacy functions
+//
+
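+// Stand-in for the removed packing::record_nchw_to_buffer_op: dispatches the
+// buffer_to_buffer shader with one invocation per element to copy NCHW data
+// from a staging buffer into a buffer-backed vTensor.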
+void record_nchw_to_buffer_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_dst.gpu_numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {32u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_dst.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_dst.buffer_metadata(),
+      src_buffer,
+      cpu_buffer_metadata.buffer());
+}
+
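+// Inverse of record_nchw_to_buffer_op: copies a buffer-backed vTensor back
+// into a destination buffer in NCHW order, forwarding the return value of
+// submit_compute_job to the caller.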
+bool record_buffer_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_src.numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {4u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_src.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      dst_buffer,
+      cpu_buffer_metadata.buffer(),
+      v_src.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_src.buffer_metadata());
+}
+
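+// Packs NCHW data from a staging buffer into a texture-backed vTensor, using
+// the shader that get_nchw_to_image_shader selects for the tensor and one
+// invocation per output texel.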
+void record_nchw_to_image_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_dst));
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      get_nchw_to_image_shader(v_dst),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      src_buffer,
+      params.buffer());
+}
+
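+// Unpacks a texture-backed vTensor into a destination buffer in NCHW order,
+// mirroring record_nchw_to_image_op; forwards submit_compute_job's result.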
+bool record_image_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  api::utils::uvec3 global_size = v_src.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_src));
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      get_image_to_nchw_shader(v_src),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_src.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      dst_buffer,
+      params.buffer());
+}
+
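+// Replaces arithmetic::record_op: binds two input images, an output image,
+// and an ArithmeticParams uniform block (input/output sizes plus alpha,
+// assumed to come from the OpUtils.h include added above), then dispatches
+// the given binary-op shader over the output extents.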
+void record_arithmetic_op(
+    api::Context* const context,
+    const api::ShaderInfo& compute_shader,
+    vTensor& v_in1,
+    vTensor& v_in2,
+    vTensor& v_dst,
+    const float alpha) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  ArithmeticParams block{
+      get_size_as_ivec4(v_dst),
+      get_size_as_ivec4(v_in1),
+      get_size_as_ivec4(v_in2),
+      alpha,
+  };
+  api::UniformParamsBuffer params(context, block);
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      compute_shader,
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      params.buffer());
+}
+
+//
+// Utilities
+//
+
 void fill_vtensor(vTensor& vten, std::vector<float>& data) {
   api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size());

   copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes());

   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_nchw_to_buffer_op(
-        api::context(), staging_buffer.buffer(), vten, {}, VK_NULL_HANDLE);
+    record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten);
   } else {
-    api::ShaderInfo compute_shader = packing::get_nchw_to_image_shader(vten);
-    packing::record_nchw_to_image_op(
-        api::context(),
-        compute_shader,
-        staging_buffer.buffer(),
-        vten,
-        {},
-        VK_NULL_HANDLE);
+    record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten);
   }
 }
@@ -75,17 +206,9 @@ void extract_vtensor(vTensor& vten, std::vector<float>& data) {
       api::context(), api::kFloat, vten.gpu_numel());

   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_buffer_to_nchw_op(
-        api::context(), vten, staging_buffer.buffer(), {}, VK_NULL_HANDLE);
+    record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   } else {
-    api::ShaderInfo compute_shader = packing::get_image_to_nchw_shader(vten);
-    packing::record_image_to_nchw_op(
-        api::context(),
-        compute_shader,
-        vten,
-        staging_buffer.buffer(),
-        {},
-        VK_NULL_HANDLE);
+    record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   }

   api::VulkanFence fence = api::context()->fences().get_fence();
@@ -208,14 +331,14 @@ TEST_F(VulkanComputeAPITest, texture_add_sanity_check) {
   std::fill(data_b.begin(), data_b.end(), 1.5f);

   // Add shader kernel
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);

   // Fill input tensors
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);

   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);

   // Extract output tensor
   std::vector<float> data_out(c.gpu_numel());
@@ -244,7 +367,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   std::vector<float> data_b(b.gpu_numel());
   std::fill(data_b.begin(), data_b.end(), 1.5f);

-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);

   // Allocate memory at the last possible opportunity
   api::MemoryAllocation a_mem = allocate_memory_for(a);
@@ -260,7 +383,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);

-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);

   std::vector<float> data_c(c.gpu_numel());
   extract_vtensor(c, data_c);
@@ -310,20 +433,20 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) {
   std::fill(data_d.begin(), data_d.end(), 1.0f);

   // Get shader kernel for add
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);

   // First, fill a and b with data
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);

   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);

   // Now d can be filled with data
   fill_vtensor(d, data_d);

   // c + d -> e
-  arithmetic::record_op(api::context(), kernel, c, d, e, 1.0f);
+  record_arithmetic_op(api::context(), kernel, c, d, e, 1.0f);

   // Extract data from e
   std::vector<float> data_e(e.gpu_numel());