12 changes: 6 additions & 6 deletions test/dynamo/test_dynamo.py
@@ -245,22 +245,22 @@ def fn_fallback(t):
     cpu_res = fn_fallback(t)
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 8)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
Collaborator (Author) commented:

The CompileTime and ExecuteTime metrics differ only for this test_fallback_multiple_submodules unit test. With allows_single_node_partition=True, the partitioning is different, so this test sees slightly different CompileTime and ExecuteTime counts.
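A minimal sketch of what the flag changes, using a hypothetical toy function and operator-support rule (SupportAllButMul and fn are illustrative, not from this PR): with allows_single_node_partition=True, a supported node that the fallback split leaves isolated still gets its own fused partition, so one more graph is compiled and executed.

```python
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class SupportAllButMul(OperatorSupport):
    # Hypothetical support rule: treat torch.mul as the fallback op.
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target is not torch.mul

def fn(x):
    y = x + 1            # supported
    z = torch.mul(y, y)  # "fallback" op that splits the graph
    return z - 1         # supported, but isolated after the split

gm = torch.fx.symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(
    gm, SupportAllButMul(), allows_single_node_partition=True)
# With the flag, the lone `z - 1` node becomes its own partition instead
# of being dropped, so an extra graph gets compiled and executed.
print(len(partitioner.propose_partitions()))  # 2 partitions in this toy case
```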

+    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)

     # Second tracing
     met.clear_counters()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)

     # Verify that dynamo can handle different inputs
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 5)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+    self.assertEqual(met.metric_data('CompileTime')[0], 4)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 15)


 class DynamoTrainingBasicTest(unittest.TestCase):
11 changes: 9 additions & 2 deletions torch_xla/core/dynamo_bridge.py
@@ -10,6 +10,7 @@

 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.utils.fuser_utils import topo_sort
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as metrics
@@ -421,10 +422,16 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
"call_function", "call_module", "call_method"
] and (node not in fallback_ops or node.target == operator.getitem)

-  # partition the model and execute to collect inputs
+  # partition the model
   supported_ops = XlaOperatorSupport()
-  partitioner = CapabilityBasedPartitioner(xla_model, supported_ops)
+  partitioner = CapabilityBasedPartitioner(xla_model, supported_ops, allows_single_node_partition=True)
   partitions = partitioner.propose_partitions()
+
+  # propose_partitions() does not guarantee topological order, so sort it manually
+  for partition in partitions:
+    partition.nodes = topo_sort(partition.nodes)
+
+  # fuse partitions and execute to collect inputs
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)

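For reference, a self-contained sketch of the partition-and-fuse flow in this hunk, with a toy fx function and a permissive operator-support class standing in for XlaOperatorSupport (SupportCompute and fn are hypothetical): propose_partitions() returns each partition's nodes without an ordering guarantee, and fuse_partitions() assumes a valid execution order, hence the topo_sort pass the PR adds.

```python
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import topo_sort

class SupportCompute(OperatorSupport):
    # Hypothetical stand-in for XlaOperatorSupport, minus the
    # fallback-op bookkeeping: accept every compute node.
    def is_node_supported(self, submodules, node):
        return node.op in ["call_function", "call_module", "call_method"]

def fn(x):
    return (x + 1) * 2 - 3

gm = torch.fx.symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(
    gm, SupportCompute(), allows_single_node_partition=True)
partitions = partitioner.propose_partitions()

# Each proposed partition holds its nodes without an ordering guarantee;
# sort each one topologically before fusing, as the PR does above.
for partition in partitions:
    partition.nodes = topo_sort(partition.nodes)

partitioned_graph = partitioner.fuse_partitions(partitions)
print(partitioned_graph.code)  # one fused submodule wrapping add/mul/sub
```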