Commit 7dfbef0

Fix runtime error when run dynamo with a profiler scope (#6913)
1 parent a170ffe

6 files changed: +36 -13 lines

test/dynamo/test_dynamo.py

Lines changed: 15 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.debug.metrics as met
 from torch_xla import runtime as xr
+import torch_xla.debug.profiler as xp
 import torch.optim as optim
 import torch.nn as nn
 import torch._dynamo as dynamo
@@ -91,6 +92,20 @@ def test_mark_step_after_dynamo(self):
     self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])


+class DynamoProfilerTest(unittest.TestCase):
+
+  def dummy_fn(self, a):
+    return torch.sin(a) + a
+
+  def test_dynamo_with_trace(self):
+    dynamo_dummy = torch.compile(
+        self.dummy_fn, backend="openxla", fullgraph=True)
+    t = torch.randn(2, 3, 4, device=xm.xla_device())
+    for i in range(10):
+      with xp.Trace('build_graph'):
+        t = dynamo_dummy(t)
+
+
 class DynamoInferenceBasicTest(unittest.TestCase):

   @classmethod
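
For reference, the pattern the new test exercises can also be run standalone. A minimal sketch, assuming a torch_xla build that includes this commit and an available XLA device:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp


def dummy_fn(a):
  return torch.sin(a) + a


compiled = torch.compile(dummy_fn, backend="openxla", fullgraph=True)
t = torch.randn(2, 3, 4, device=xm.xla_device())
for _ in range(10):
  # Before this commit, mark_step() calls issued inside the dynamo bridge
  # reset the trace scope stack while xp.Trace was still open, triggering
  # the runtime error this commit fixes.
  with xp.Trace('build_graph'):
    t = compiled(t)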

torch_xla/core/dynamo_bridge.py

Lines changed: 5 additions & 3 deletions

@@ -429,7 +429,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
     for xla_arg in xla_model.xla_args:
       if isinstance(xla_arg, torch.Tensor):
         print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
-  xm.mark_step()
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -614,8 +615,9 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
       torch._functionalize_sync(a)

-  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
-  xm.mark_step()
+  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids.
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)

   # Find tensor constructor nodes that create CPU tensors, and make
   # them create XLA tensors, where possible, instead. i.e. replace the
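
Both call sites above can execute while user code is still inside an xp.Trace scope (exactly the situation in the new test), so they must not tear down the scope stack. A short sketch of the behavior reset_scope=False preserves, assuming a build with this commit:

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

with xp.Trace('outer'):
  # Cutting the graph here no longer clears the scope stack, so the
  # 'outer' trace scope is popped cleanly when the block exits.
  xm.mark_step(reset_scope=False)

# Outside any trace scope, the default behavior is unchanged:
xm.mark_step()  # reset_scope defaults to True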

torch_xla/core/xla_model.py

Lines changed: 3 additions & 2 deletions

@@ -1045,7 +1045,7 @@ def _run_step_closures():
   return devctx


-def mark_step(wait=False):
+def mark_step(wait=False, reset_scope=True):
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',
@@ -1054,7 +1054,8 @@ def mark_step(wait=False):
       flush=True)
   torch_xla._XLAC._xla_step_marker(
       torch_xla._XLAC._xla_get_default_device(), [],
-      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
+      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
+      reset_scope=reset_scope)
   # Only emit metrics from the first local device index, to avoid emitting the
   # same values from different threads.
   if is_master_ordinal():
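
Since reset_scope defaults to True, the public mark_step() API stays backward compatible: existing callers see no change, and only code that may run under a profiler scope (the dynamo bridge above) opts out. The call forms, as a sketch:

import torch_xla.core.xla_model as xm

xm.mark_step()                   # unchanged default: scopes are reset at the step boundary
xm.mark_step(wait=True)          # wait can also be forced via the XLA_SYNC_WAIT env var
xm.mark_step(reset_scope=False)  # keeps any enclosing profiler trace scope alive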

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 7 additions & 5 deletions

@@ -458,12 +458,13 @@ void SyncLiveTensors(const std::string& device_str,
 }

 void StepMarker(const std::string& device_str,
-                const std::vector<std::string>& devices, bool wait) {
+                const std::vector<std::string>& devices, bool wait,
+                bool reset_scope) {
   tsl::profiler::TraceMe activity("StepMarker",
                                   tsl::profiler::TraceMeLevel::kInfo);
   torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
   XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait);
-  XLAGraphExecutor::Get()->MarkStep(device);
+  XLAGraphExecutor::Get()->MarkStep(device, reset_scope);
   bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
   if (TF_PREDICT_FALSE(debug_mode)) {
     std::string report = runtime::metrics::CreatePerformanceReport(
@@ -1649,11 +1650,12 @@ void InitXlaModuleBindings(py::module m) {
   m.def(
       "_xla_step_marker",
       [](const std::string& device, const std::vector<std::string>& devices,
-         bool wait) {
+         bool wait, bool reset_scope) {
         NoGilSection nogil;
-        StepMarker(device, devices, wait);
+        StepMarker(device, devices, wait, reset_scope);
       },
-      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
+      py::arg("reset_scope") = true);
   m.def("_get_stablehlo",
         [](const std::vector<at::Tensor>& tensors, const std::string& device,
            const std::vector<std::string>& devices,
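
The new pybind argument also defaults to true, so older Python callers of the private binding keep working. For illustration only, the call xm.mark_step() now makes through this binding (internal API, not meant for direct use):

import torch_xla

torch_xla._XLAC._xla_step_marker(
    torch_xla._XLAC._xla_get_default_device(), [],
    wait=False,
    reset_scope=True)  # new keyword; can be omitted thanks to its default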

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 5 additions & 2 deletions

@@ -404,14 +404,17 @@ void XLAGraphExecutor::SyncLiveTensorsGraph(
   SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
 }

-void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
+void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device,
+                                bool reset_scope) {
   // TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
   // remain as XLA_COUNTER to support
   // runtime::metrics::CreatePerformanceReport(). For more information, see
   // NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
   XLA_COUNTER("MarkStep", 1);
   DeviceContextArena::Get()->MarkStep(device);
-  torch::lazy::ScopePusher::ResetScopes();
+  if (reset_scope) {
+    torch::lazy::ScopePusher::ResetScopes();
+  }
   ResetTrimCounter();
 }
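
The guarded ResetScopes() is the heart of the fix. A toy Python model of the failure mode, for intuition only (hypothetical names; the real scope stack is torch::lazy::ScopePusher on the C++ side):

# Toy model: a profiler trace pushes a scope on entry and pops it on exit.
scopes = []

class Trace:  # stands in for xp.Trace

  def __init__(self, name):
    self.name = name

  def __enter__(self):
    scopes.append(self.name)

  def __exit__(self, *exc):
    scopes.pop()  # fails if mark_step already cleared the stack

def mark_step(reset_scope=True):
  if reset_scope:
    scopes.clear()  # old behavior: unconditional ScopePusher::ResetScopes()

with Trace('build_graph'):
  mark_step(reset_scope=False)  # fixed path: the open scope survives

try:
  with Trace('build_graph'):
    mark_step()  # old path: wipes the stack while the scope is still open
except IndexError:
  print('resetting scopes inside an open trace breaks scope exit')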

torch_xla/csrc/xla_graph_executor.h

Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // Marks an execution step, which allows the tensor framework to understand
   // the computation boundaries.
   // Override to use our own DeviceContextArena.
-  void MarkStep(const torch::lazy::BackendDevice& device) final;
+  void MarkStep(const torch::lazy::BackendDevice& device, bool reset_scope);

   // Waits for all the outstanding operations on all the supplied devices.
   // If devices is empty, the wait will happen for all local devices.
