Skip to content

Commit 3d8faf9

Browse files
lsy323ManfeiBai
authored andcommitted
Enable lowering for upsample_bilinear2d with scale factor (#4464)
* Enable lowering for upsample_bilinear2d with scale factor * fix linter * add shape validation
1 parent ab6686b commit 3d8faf9

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4311,6 +4311,47 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2D) {
43114311
cpp_test::GetIgnoredCounters());
43124312
}
43134313

4314+
TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DWithScale) {
4315+
struct ImageInfo {
4316+
int batch_size;
4317+
int h;
4318+
int w;
4319+
int chans;
4320+
double scale_h;
4321+
double scale_w;
4322+
};
4323+
4324+
/* clang-format off */
4325+
std::vector<ImageInfo> inputs = {
4326+
{/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/8.0/5, /*scale_w*/8.0/5},
4327+
{/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/255.0/1335, /*scale_w*/255.0/1335},
4328+
{/*batch_size=*/2, /*h=*/255, /*w=*/255, /*chans=*/3, /*scale_h*/1335.0/255, /*scale_w*/1335.0/255},
4329+
{/*batch_size=*/2, /*h=*/254, /*w=*/243, /*chans=*/3, /*scale_h*/784.0/254, /*scale_w*/214.0/243}
4330+
};
4331+
/* clang-format on */
4332+
4333+
for (const auto& img_info : inputs) {
4334+
for (bool align_corners : {true, false}) {
4335+
torch::Tensor input = torch::rand(
4336+
{img_info.batch_size, img_info.chans, img_info.h, img_info.w},
4337+
torch::TensorOptions(torch::kFloat));
4338+
ForEachDevice([&](const torch::Device& device) {
4339+
torch::Tensor xla_input = CopyToDevice(input, device);
4340+
torch::Tensor result = torch::upsample_bilinear2d(
4341+
input, c10::nullopt, align_corners,
4342+
at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
4343+
torch::Tensor xla_result = torch::upsample_bilinear2d(
4344+
xla_input, c10::nullopt, align_corners,
4345+
at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
4346+
AllClose(result, xla_result, /*rtol=*/1e-4, /*atol=*/1e-4);
4347+
});
4348+
}
4349+
}
4350+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
4351+
ExpectCounterChanged("xla::upsample_bilinear2d",
4352+
cpp_test::GetIgnoredCounters());
4353+
}
4354+
43144355
TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) {
43154356
int batch_size = 2;
43164357
int h = 5;

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,16 +2904,21 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d(
29042904
c10::optional<double> scales_h, c10::optional<double> scales_w) {
29052905
TORCH_LAZY_FN_COUNTER("xla::");
29062906
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
2907+
absl::Span<const int64_t> input_dims =
2908+
self_tensor->shape().get().dimensions();
2909+
std::vector<int64_t> scaled_output_size =
2910+
torch::lazy::ToVector<int64_t>(output_size);
29072911
if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) {
2908-
return at::native::call_fallback_fn<
2909-
&xla_cpu_fallback, ATEN_OP(upsample_bilinear2d)>::call(self,
2910-
output_size,
2911-
align_corners,
2912-
scales_h,
2913-
scales_w);
2912+
scaled_output_size = GetOutputSizeWithScale(input_dims, scales_h, scales_w,
2913+
scaled_output_size);
2914+
if (!output_size.empty()) {
2915+
XLA_CHECK(scaled_output_size.at(0) == output_size.at(0) &&
2916+
scaled_output_size.at(1) == output_size.at(1))
2917+
<< "Inferred output size and output_size from upstream are different";
2918+
}
29142919
}
29152920
return bridge::AtenFromXlaTensor(tensor_methods::upsample_bilinear2d(
2916-
self_tensor, torch::lazy::ToVector<int64_t>(output_size), align_corners));
2921+
self_tensor, scaled_output_size, align_corners));
29172922
}
29182923

29192924
at::Tensor XLANativeFunctions::upsample_bilinear2d_backward(

0 commit comments

Comments
 (0)