diff --git a/test/test_devices.py b/test/test_devices.py index ff93f64a5c50..e1fc804736d8 100644 --- a/test/test_devices.py +++ b/test/test_devices.py @@ -4,14 +4,19 @@ import torch import torch_xla as xla import torch_xla.runtime as xr +import torch_xla.debug.metrics as met class TestDevices(parameterized.TestCase): - def setUpClass(): + @classmethod + def setUpClass(cls): xr.set_device_type('CPU') os.environ['CPU_NUM_DEVICES'] = '4' + def tearDown(self): + met.clear_metrics() + @parameterized.parameters((None, torch.device('xla:0')), (0, torch.device('xla:0')), (3, torch.device('xla:3'))) @@ -29,6 +34,12 @@ def test_real_devices(self): def test_device_count(self): self.assertEqual(xla.device_count(), 4) + def test_sync(self): + torch.ones((3, 3), device=xla.device()) + xla.sync() + + self.assertEqual(met.counter_value('MarkStep'), 1) + if __name__ == "__main__": absltest.main() diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 961f6a3217ed..141d7e3e5a76 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -45,3 +45,8 @@ def real_devices() -> List[str]: def device_count() -> int: """Returns number of addressable devices in the current process.""" return len(real_devices()) + + +def sync(): + """Launches all pending graph operations.""" + xm.mark_step()