
Commit b58327e

Add topological sorting to dynamo partitions (#5472)
* Add topological sorting to dynamo partitions
* Run linter
* Update unit tests to include more in-place ops
1 parent 6270cba commit b58327e
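
The user-facing scenario behind the test changes is a model that mutates its own state in place while compiled for XLA. A minimal sketch of that scenario, assuming a working torch_xla install (the module and variable names here are illustrative, not part of the commit):

    import torch
    import torch_xla.core.xla_model as xm

    class Accumulator(torch.nn.Module):
      # Toy module that mutates its own buffer in place, like the new test's TestModel.

      def __init__(self, device=None):
        super().__init__()
        self.state = torch.zeros((5, 3), device=device)

      def forward(self, index, update):
        self.state.index_copy_(0, index, update)  # in-place op inside the compiled region
        return self.state + 1

    device = xm.xla_device()
    model = Accumulator(device).to(device)
    compiled = torch.compile(model, backend='openxla')
    out = compiled(
        torch.tensor([0, 4, 2, 1, 3], device=device),
        torch.rand(5, 3, device=device))
    print(out.cpu())

Correct results here depend on the partitioned FX graph replaying the in-place op and the later read in dependency order, which is what the sorting added in this commit enforces.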

2 files changed: +40, -22 lines


test/dynamo/test_dynamo.py

Lines changed: 30 additions & 20 deletions
@@ -101,33 +101,43 @@ def __init__(self, device=None):
         super().__init__()
         self.self_tensor = torch.zeros((5, 3), device=device)

-      def forward(self, index, copy_tensor, input_tensor):
+      def copy_(self, index, copy_tensor):
         self.self_tensor.index_copy_(0, index, copy_tensor)
+
+      def add_(self, index, other_tensor):
+        self.self_tensor.add_(other_tensor)
+
+      def abs_(self, index, other_tensor):
+        self.self_tensor.abs_()
+
+      def forward(self, index, copy_tensor, input_tensor, op_name):
+        getattr(self, op_name)(index, copy_tensor)
         output = input_tensor + self.self_tensor
         return output

     torch._dynamo.reset()
     met.clear_counters()
     met.clear_all()
     device = xm.xla_device()
-    input_tensor = torch.ones(3)
-    copy_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-                               dtype=torch.float)
-    index = torch.tensor([0, 4, 2])
-    xla_input_tensor = input_tensor.to(device)
-    xla_copy_tensor = copy_tensor.to(device)
-    xla_index = index.to(device)

     cpu_model = TestModel()
-    res_cpu = cpu_model.forward(index, copy_tensor, input_tensor)
-
     xla_model = TestModel(device).to(device)
     compiled_model = torch.compile(xla_model, backend='openxla')
-    res_xla_dynamo = compiled_model.forward(xla_index, xla_copy_tensor,
-                                            xla_input_tensor)

-    self.assertIn('xla::index_copy', met.counter_names())
-    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
+    input_tensor = torch.ones(3)
+    copy_tensor = torch.rand(5, 3)
+    index = torch.tensor([0, 4, 2, 1, 3])
+    xla_input_tensor = input_tensor.to(device)
+    xla_copy_tensor = copy_tensor.to(device)
+    xla_index = index.to(device)
+
+    in_place_ops = ['copy_', 'add_', 'abs_']
+    for in_place_op in in_place_ops:
+      res_cpu = cpu_model.forward(
+          index, copy_tensor, input_tensor, op_name=in_place_op)
+      res_xla_dynamo = compiled_model.forward(
+          xla_index, xla_copy_tensor, xla_input_tensor, op_name=in_place_op)
+      self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

   def test_simple_model_with_different_input_shape(self):
     met.clear_counters()
@@ -245,22 +255,22 @@ def fn_fallback(t):
     cpu_res = fn_fallback(t)
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 8)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)

     # Second tracing
     met.clear_counters()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)

     # Verify that dynamo can handle different inputs
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 5)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+    self.assertEqual(met.metric_data('CompileTime')[0], 4)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 15)


 class DynamoTrainingBasicTest(unittest.TestCase):

torch_xla/core/dynamo_bridge.py

Lines changed: 10 additions & 2 deletions
@@ -10,6 +10,7 @@

 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.utils.fuser_utils import topo_sort
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as metrics
@@ -421,10 +422,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
           "call_function", "call_module", "call_method"
       ] and (node not in fallback_ops or node.target == operator.getitem)

-  # partition the model and execute to collect inputs
+  # partition the model
   supported_ops = XlaOperatorSupport()
-  partitioner = CapabilityBasedPartitioner(xla_model, supported_ops)
+  partitioner = CapabilityBasedPartitioner(
+      xla_model, supported_ops, allows_single_node_partition=True)
   partitions = partitioner.propose_partitions()
+
+  # propose_partitions() does not guarantee topological order, so sort it manually
+  for partition in partitions:
+    partition.nodes = topo_sort(partition.nodes)
+
+  # fuse partitions and execute to collect inputs
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
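
For readers less familiar with the FX partitioner API used above, the same partition-sort-fuse pattern can be exercised on a toy graph. A minimal, self-contained sketch, assuming the torch.fx APIs as they existed at the time of this commit (the toy function and support class are illustrative, not torch_xla code):

    import torch
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import OperatorSupport
    from torch.fx.passes.utils.fuser_utils import topo_sort

    class AllCallsSupported(OperatorSupport):
      # Toy stand-in for XlaOperatorSupport: accept every call node.

      def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op in ("call_function", "call_module", "call_method")

    def fn(x):
      y = x + 1
      z = y * 2
      return z - y

    gm = torch.fx.symbolic_trace(fn)
    partitioner = CapabilityBasedPartitioner(
        gm, AllCallsSupported(), allows_single_node_partition=True)
    partitions = partitioner.propose_partitions()

    # propose_partitions() does not promise topological order within a partition,
    # so sort each one before fusing, mirroring the diff above.
    for partition in partitions:
      partition.nodes = topo_sort(partition.nodes)

    fused = partitioner.fuse_partitions(partitions)
    print(fused.graph)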
