Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion test/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import torch
import torch_xla as xla
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met


class TestDevices(parameterized.TestCase):

def setUpClass():
@classmethod
def setUpClass(cls):
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'

def tearDown(self):
met.clear_metrics()

@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
(3, torch.device('xla:3')))
Expand All @@ -29,6 +34,12 @@ def test_real_devices(self):
def test_device_count(self):
self.assertEqual(xla.device_count(), 4)

def test_sync(self):
torch.ones((3, 3), device=xla.device())
xla.sync()

self.assertEqual(met.counter_value('MarkStep'), 1)


if __name__ == "__main__":
absltest.main()
5 changes: 5 additions & 0 deletions torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def real_devices() -> List[str]:
def device_count() -> int:
"""Returns number of addressable devices in the current process."""
return len(real_devices())


def sync():
"""Launches all pending graph operations."""
xm.mark_step()