|
35 | 35 | all_types_and_complex_and, |
36 | 36 | all_types_and, |
37 | 37 | ) |
| 38 | +import torch.utils._pytree as pytree |
38 | 39 | import torch_xla |
39 | 40 | import torch_xla.core.xla_builder as xb |
40 | 41 | import torch_xla.core.xla_op_registry as xor |
@@ -2384,24 +2385,35 @@ def test_cummax_0_sized_dimension(self): |
2384 | 2385 |
|
2385 | 2386 | self.assertEqual(actual, expected) |
2386 | 2387 |
|
2387 | | - def test_conj(self): |
2388 | | - # Leave the factory out of the fallback count. |
2389 | | - tensor = torch.rand(2, 2, dtype=torch.complex64) |
2390 | | - |
| 2388 | + def _test_no_fallback(self, runf, args): |
2391 | 2389 | met.clear_all() |
2392 | 2390 |
|
2393 | 2391 | def run(device): |
2394 | | - return torch.conj(tensor.to(device)) |
| 2392 | + args_ = pytree.tree_map_only(torch.Tensor, |
| 2393 | + lambda t: t.clone().detach().to(device), |
| 2394 | + args) |
| 2395 | + return runf(*args_) |
2395 | 2396 |
|
2396 | 2397 | actual = run("cpu") |
2397 | 2398 | expected = run(xm.xla_device()) |
2398 | 2399 |
|
2399 | | - self.assertEqual( |
2400 | | - met.executed_fallback_ops(), [], |
2401 | | - message="expected no fallback operations.") |
| 2400 | + self.assertFalse( |
| 2401 | + met.executed_fallback_ops(), msg="expected no fallback operations.") |
2402 | 2402 | self.assertEqual( |
2403 | 2403 | actual, expected.cpu(), message="XLA results should match CPU results.") |
2404 | 2404 |
|
| 2405 | + def test_conj_no_fallback(self): |
| 2406 | + tensor = torch.rand(2, 2, dtype=torch.complex64) |
| 2407 | + self._test_no_fallback(torch.conj, (tensor,)) |
| 2408 | + |
| 2409 | + def test_isneginf_no_fallback(self): |
| 2410 | + t = torch.rand(10) |
| 2411 | + # Scale the tensor elements. |
| 2412 | + t = t * 100_000 |
| 2413 | + # Convert to a lower precision data-type so as to get a few infs. |
| 2414 | + t = t.to(torch.float16) |
| 2415 | + self._test_no_fallback(torch.isneginf, (t,)) |
| 2416 | + |
2405 | 2417 |
|
2406 | 2418 | class MNISTComparator(nn.Module): |
2407 | 2419 |
|
|
0 commit comments