diff --git a/kernels/portable/cpu/op_split_with_sizes_copy.cpp b/kernels/portable/cpu/op_split_with_sizes_copy.cpp index 828b14b24ce..7d1b485e7a4 100644 --- a/kernels/portable/cpu/op_split_with_sizes_copy.cpp +++ b/kernels/portable/cpu/op_split_with_sizes_copy.cpp @@ -55,8 +55,7 @@ void split_with_sizes_copy_out( target_out_sizes[dim] = static_cast(split_sizes[i]); ET_KERNEL_CHECK( ctx, - tensor_is_broadcastable_to( - {target_out_sizes, target_out_ndim}, out[i].sizes()), + resize_tensor(out[i], {target_out_sizes, target_out_ndim}) == Error::Ok, InvalidArgument, ); } diff --git a/kernels/test/op_split_with_sizes_copy_test.cpp b/kernels/test/op_split_with_sizes_copy_test.cpp index f67789c3561..91ef94af653 100644 --- a/kernels/test/op_split_with_sizes_copy_test.cpp +++ b/kernels/test/op_split_with_sizes_copy_test.cpp @@ -27,66 +27,88 @@ class OpSplitWithSizesCopyOutTest : public OperatorTest { return torch::executor::aten::split_with_sizes_copy_outf( context_, self, split_sizes, dim, out); } + + void test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism dynamism) { + torch::executor::testing::TensorFactory + tfFloat; + + exec_aten::Tensor self = tfFloat.make( + {2, 6, 3}, + {-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125, + 14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0, + 46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5, + -61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125, + 0.375, -61.375, 65.0, -99.375}); + ::std::vector split_sizes_vec = {3, 1, 2}; + exec_aten::ArrayRef split_sizes = exec_aten::ArrayRef( + split_sizes_vec.data(), split_sizes_vec.size()); + int64_t dim = 1; + + ::std::vector out_vec; + if (dynamism == exec_aten::TensorShapeDynamism::STATIC) { + out_vec = { + tfFloat.zeros({2, 3, 3}), + tfFloat.zeros({2, 1, 3}), + tfFloat.zeros({2, 2, 3})}; + } else { // dynamism == exec_aten::TensorShapeDynamism::DYNAMIC_BOUND + out_vec = { + tfFloat.zeros( + {2, 3, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND), + tfFloat.zeros( + {2, 1, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND), + tfFloat.zeros( + {2, 2, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND)}; + } + + exec_aten::TensorList out = + exec_aten::TensorList(out_vec.data(), out_vec.size()); + ::std::vector out_expected_vec = { + tfFloat.make( + {2, 3, 3}, + {-31.25, + -92.75, + -39.75, + -3.25, + 53.875, + 88.25, + -0.625, + -1.125, + 14.75, + -58.125, + 19.625, + -71.125, + 64.75, + -1.375, + -83.5, + -61.375, + 13.125, + 28.625}), + tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}), + tfFloat.make( + {2, 2, 3}, + {-8.0, + -64.125, + 23.0, + 37.0, + 46.125, + -83.25, + -88.875, + -79.125, + 0.375, + -61.375, + 65.0, + -99.375})}; + exec_aten::TensorList out_expected = + exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size()); + op_split_with_sizes_copy_out(self, split_sizes, dim, out); + EXPECT_TENSOR_LISTS_CLOSE(out, out_expected); + } }; TEST_F(OpSplitWithSizesCopyOutTest, SanityCheckDim1) { - torch::executor::testing::TensorFactory tfFloat; + test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism::STATIC); +} - exec_aten::Tensor self = tfFloat.make( - {2, 6, 3}, - {-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125, - 14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0, - 46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5, - -61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125, - 0.375, -61.375, 65.0, -99.375}); - ::std::vector split_sizes_vec = {3, 1, 2}; - exec_aten::ArrayRef split_sizes = exec_aten::ArrayRef( - split_sizes_vec.data(), split_sizes_vec.size()); - int64_t dim = 1; - ::std::vector out_vec = { - tfFloat.zeros({2, 3, 3}), - tfFloat.zeros({2, 1, 3}), - tfFloat.zeros({2, 2, 3})}; - exec_aten::TensorList out = - exec_aten::TensorList(out_vec.data(), out_vec.size()); - ::std::vector out_expected_vec = { - tfFloat.make( - {2, 3, 3}, - {-31.25, - -92.75, - -39.75, - -3.25, - 53.875, - 88.25, - -0.625, - -1.125, - 14.75, - -58.125, - 19.625, - -71.125, - 64.75, - -1.375, - -83.5, - -61.375, - 13.125, - 28.625}), - tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}), - tfFloat.make( - {2, 2, 3}, - {-8.0, - -64.125, - 23.0, - 37.0, - 46.125, - -83.25, - -88.875, - -79.125, - 0.375, - -61.375, - 65.0, - -99.375})}; - exec_aten::TensorList out_expected = - exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size()); - op_split_with_sizes_copy_out(self, split_sizes, dim, out); - EXPECT_TENSOR_LISTS_CLOSE(out, out_expected); +TEST_F(OpSplitWithSizesCopyOutTest, DynamicShape) { + test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); }