Skip to content

Commit e82231f

Browse files
authored
[CherryPick] Lower isneginf(). (#8926)
1 parent cbfda4c commit e82231f

File tree

6 files changed

+38
-8
lines changed

6 files changed

+38
-8
lines changed

codegen/xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ full_codegen:
6060
- hardswish_backward
6161
- inverse
6262
- isnan
63+
- isneginf
6364
- leaky_relu
6465
- le.Scalar
6566
- le.Tensor

test/test_operations.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
all_types_and_complex_and,
3636
all_types_and,
3737
)
38+
import torch.utils._pytree as pytree
3839
import torch_xla
3940
import torch_xla.core.xla_builder as xb
4041
import torch_xla.core.xla_op_registry as xor
@@ -2384,24 +2385,35 @@ def test_cummax_0_sized_dimension(self):
23842385

23852386
self.assertEqual(actual, expected)
23862387

2387-
def test_conj(self):
2388-
# Leave the factory out of the fallback count.
2389-
tensor = torch.rand(2, 2, dtype=torch.complex64)
2390-
2388+
def _test_no_fallback(self, runf, args):
23912389
met.clear_all()
23922390

23932391
def run(device):
2394-
return torch.conj(tensor.to(device))
2392+
args_ = pytree.tree_map_only(torch.Tensor,
2393+
lambda t: t.clone().detach().to(device),
2394+
args)
2395+
return runf(*args_)
23952396

23962397
actual = run("cpu")
23972398
expected = run(xm.xla_device())
23982399

2399-
self.assertEqual(
2400-
met.executed_fallback_ops(), [],
2401-
message="expected no fallback operations.")
2400+
self.assertFalse(
2401+
met.executed_fallback_ops(), msg="expected no fallback operations.")
24022402
self.assertEqual(
24032403
actual, expected.cpu(), message="XLA results should match CPU results.")
24042404

2405+
def test_conj_no_fallback(self):
2406+
tensor = torch.rand(2, 2, dtype=torch.complex64)
2407+
self._test_no_fallback(torch.conj, (tensor,))
2408+
2409+
def test_isneginf_no_fallback(self):
2410+
t = torch.rand(10)
2411+
# Scale the tensor elements.
2412+
t = t * 100_000
2413+
# Convert to a lower precision data-type so as to get a few infs.
2414+
t = t.to(torch.float16)
2415+
self._test_no_fallback(torch.isneginf, (t,))
2416+
24052417

24062418
class MNISTComparator(nn.Module):
24072419

test/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def get_allowed_ops_map(
127127
AllowedOpInfoEntry('imag'),
128128
AllowedOpInfoEntry('inverse'),
129129
AllowedOpInfoEntry('isin'),
130+
AllowedOpInfoEntry('isneginf'),
130131
AllowedOpInfoEntry('le'),
131132
AllowedOpInfoEntry('linalg.cholesky'),
132133
AllowedOpInfoEntry('linalg.cholesky_ex'),

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "torch_xla/csrc/xla_lower_util.h"
1414
#include "xla/client/lib/math.h"
1515
#include "xla/client/lib/matrix.h"
16+
#include "xla/hlo/builder/lib/constants.h"
1617
#include "xla/hlo/builder/lib/logdet.h"
1718

1819
namespace torch_xla {
@@ -531,6 +532,13 @@ torch_xla::XlaOpVector Isnan::Lower(LoweringContext* loctx) const {
531532
return ReturnOp(xla::IsNan(xla_input), loctx);
532533
}
533534

535+
torch_xla::XlaOpVector Isneginf::Lower(LoweringContext* loctx) const {
536+
xla::XlaOp input = loctx->GetOutputOp(operand(0));
537+
return ReturnOp(xla::Eq(input, xla::MinValue(input.builder(),
538+
XlaHelpers::TypeOfXlaOp(input))),
539+
loctx);
540+
}
541+
534542
torch_xla::XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const {
535543
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
536544
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(1));

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,12 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) {
590590
return isnan_shape;
591591
}
592592

593+
xla::Shape IsneginfOutputShape(const torch::lazy::Value& input) {
594+
xla::Shape shape(GetXlaShape(input));
595+
shape.set_element_type(xla::PRED);
596+
return shape;
597+
}
598+
593599
xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
594600
const torch::lazy::Value& negative_slope) {
595601
auto lower_for_shape_fn =

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);
174174

175175
xla::Shape IsnanOutputShape(const torch::lazy::Value& input);
176176

177+
xla::Shape IsneginfOutputShape(const torch::lazy::Value& input);
178+
177179
xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
178180
const torch::lazy::Value& negative_slope);
179181

0 commit comments

Comments
 (0)