@@ -4311,6 +4311,47 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2D) {
43114311 cpp_test::GetIgnoredCounters ());
43124312}
43134313
4314+ TEST_F (AtenXlaTensorTest, TestUpsampleBilinear2DWithScale) {
4315+ struct ImageInfo {
4316+ int batch_size;
4317+ int h;
4318+ int w;
4319+ int chans;
4320+ double scale_h;
4321+ double scale_w;
4322+ };
4323+
4324+ /* clang-format off */
4325+ std::vector<ImageInfo> inputs = {
4326+ {/* batch_size=*/ 2 , /* h=*/ 5 , /* w=*/ 5 , /* chans=*/ 2 , /* scale_h*/ 8.0 /5 , /* scale_w*/ 8.0 /5 },
4327+ {/* batch_size=*/ 2 , /* h=*/ 1335 , /* w=*/ 1335 , /* chans=*/ 3 , /* scale_h*/ 255.0 /1335 , /* scale_w*/ 255.0 /1335 },
4328+ {/* batch_size=*/ 2 , /* h=*/ 255 , /* w=*/ 255 , /* chans=*/ 3 , /* scale_h*/ 1335.0 /255 , /* scale_w*/ 1335.0 /255 },
4329+ {/* batch_size=*/ 2 , /* h=*/ 254 , /* w=*/ 243 , /* chans=*/ 3 , /* scale_h*/ 784.0 /254 , /* scale_w*/ 214.0 /243 }
4330+ };
4331+ /* clang-format on */
4332+
4333+ for (const auto & img_info : inputs) {
4334+ for (bool align_corners : {true , false }) {
4335+ torch::Tensor input = torch::rand (
4336+ {img_info.batch_size , img_info.chans , img_info.h , img_info.w },
4337+ torch::TensorOptions (torch::kFloat ));
4338+ ForEachDevice ([&](const torch::Device& device) {
4339+ torch::Tensor xla_input = CopyToDevice (input, device);
4340+ torch::Tensor result = torch::upsample_bilinear2d (
4341+ input, c10::nullopt , align_corners,
4342+ at::ArrayRef<double >{img_info.scale_h , img_info.scale_w });
4343+ torch::Tensor xla_result = torch::upsample_bilinear2d (
4344+ xla_input, c10::nullopt , align_corners,
4345+ at::ArrayRef<double >{img_info.scale_h , img_info.scale_w });
4346+ AllClose (result, xla_result, /* rtol=*/ 1e-4 , /* atol=*/ 1e-4 );
4347+ });
4348+ }
4349+ }
4350+ ExpectCounterNotChanged (" aten::.*" , cpp_test::GetIgnoredCounters ());
4351+ ExpectCounterChanged (" xla::upsample_bilinear2d" ,
4352+ cpp_test::GetIgnoredCounters ());
4353+ }
4354+
43144355TEST_F (AtenXlaTensorTest, TestUpsampleBilinear2DBackward) {
43154356 int batch_size = 2 ;
43164357 int h = 5 ;
0 commit comments