
Commit 640019b

JackCaoG and yitongh authored and committed
Fix runtime error when running dynamo with a profiler scope (pytorch#6913)
1 parent 32faea0 commit 640019b
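
For context, here is a minimal sketch of the scenario this commit addresses (an illustration, not part of the change itself): a function compiled with the openxla dynamo backend is run inside an xp.Trace profiler scope, which previously could hit a runtime error because the bridge's internal mark_step() reset the active scope stack. The function name dummy_fn mirrors the new test; the device and shapes are illustrative.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def dummy_fn(a):
  # Any small graph is enough to route through the dynamo/XLA bridge.
  return torch.sin(a) + a

compiled_fn = torch.compile(dummy_fn, backend="openxla", fullgraph=True)
t = torch.randn(2, 3, 4, device=xm.xla_device())

# Before this fix, running the compiled function under an open profiler
# trace scope could fail, because the internal step markers cleared the
# scope stack while the Trace scope was still active.
with xp.Trace('build_graph'):
  t = compiled_fn(t)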

File tree: 6 files changed (+66 −13 lines)

test/dynamo/test_dynamo.py

Lines changed: 45 additions & 0 deletions
@@ -7,6 +7,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.debug.metrics as met
 from torch_xla import runtime as xr
+import torch_xla.debug.profiler as xp
 import torch.optim as optim
 import torch.nn as nn
 import torch._dynamo as dynamo
@@ -61,6 +62,50 @@ def test_random_op_different_result_each_run(self):
     self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))
 
 
+class DynamoLTCInteractionTest(unittest.TestCase):
+
+  def index_copy_inplace(self, cache, update_indices, xk):
+    cache.index_copy_(0, update_indices, xk)
+
+  def test_mark_step_after_dynamo(self):
+    cache_len = 512
+    kv_heads = 8
+    head_dim = 128
+    running = 16
+
+    device = xm.xla_device()
+    cache = torch.rand((cache_len, kv_heads, head_dim)).to(device)
+    update_indices = torch.randint(
+        0, cache_len, (running,), dtype=torch.long).to(device)
+    xk = torch.rand((running, kv_heads, head_dim)).to(device)
+
+    dynamo_index_copy_inplace = torch.compile(
+        self.index_copy_inplace, backend="openxla", fullgraph=True)
+    met.clear_all()
+    for i in range(10):
+      dynamo_index_copy_inplace(cache, update_indices, xk)
+    xm.wait_device_ops()
+    current_execute_time = met.metric_data('ExecuteTime')[0]
+    # This mark_step should be a no-op and don't trigger additional execution.
+    xm.mark_step()
+    xm.wait_device_ops()
+    self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])
+
+
+class DynamoProfilerTest(unittest.TestCase):
+
+  def dummy_fn(self, a):
+    return torch.sin(a) + a
+
+  def test_dynamo_with_trace(self):
+    dynamo_dummy = torch.compile(
+        self.dummy_fn, backend="openxla", fullgraph=True)
+    t = torch.randn(2, 3, 4, device=xm.xla_device())
+    for i in range(10):
+      with xp.Trace('build_graph'):
+        t = dynamo_dummy(t)
+
+
 class DynamoInferenceBasicTest(unittest.TestCase):
 
   @classmethod

torch_xla/core/dynamo_bridge.py

Lines changed: 5 additions & 3 deletions
@@ -429,7 +429,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
     for xla_arg in xla_model.xla_args:
       if isinstance(xla_arg, torch.Tensor):
         print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
-  xm.mark_step()
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -614,8 +615,9 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
       torch._functionalize_sync(a)
 
-  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
-  xm.mark_step()
+  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids.
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
 
   # Find tensor constructor nodes that create CPU tensors, and make
   # them create XLA tensors, where possible, instead. i.e. replace the
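
The new comments above capture the rationale: the bridge's internal step markers can run while the caller still holds a profiler trace scope, so they must not clear the scope stack. A small sketch of that interaction, assuming the post-commit mark_step signature:

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

with xp.Trace('build_graph'):
  # Passing reset_scope=False keeps the enclosing 'build_graph' scope on
  # the scope stack across the step marker, as the bridge now does.
  xm.mark_step(reset_scope=False)
  # A plain xm.mark_step() (reset_scope defaults to True) would clear all
  # scopes while the Trace scope above is still open.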

torch_xla/core/xla_model.py

Lines changed: 3 additions & 2 deletions
@@ -1046,7 +1046,7 @@ def _run_step_closures():
   return devctx
 
 
-def mark_step(wait=False):
+def mark_step(wait=False, reset_scope=True):
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',
@@ -1055,7 +1055,8 @@ def mark_step(wait=False):
         flush=True)
   torch_xla._XLAC._xla_step_marker(
       torch_xla._XLAC._xla_get_default_device(), [],
-      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
+      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
+      reset_scope=reset_scope)
   # Only emit metrics from the first local device index, to avoid emitting the
   # same values from different threads.
   if is_master_ordinal():
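
For reference, a short usage sketch of the updated Python API, assuming the reset_scope keyword added above; the default (True) preserves the previous behavior for ordinary callers:

import torch_xla.core.xla_model as xm

# Default: scopes are reset at the step marker, exactly as before.
xm.mark_step()

# Internal callers that may be under an active profiler trace scope
# (e.g. the dynamo bridge) opt out of the scope reset.
xm.mark_step(reset_scope=False)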

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 7 additions & 5 deletions
@@ -447,12 +447,13 @@ void SyncLiveTensors(const std::string& device_str,
 }
 
 void StepMarker(const std::string& device_str,
-                const std::vector<std::string>& devices, bool wait) {
+                const std::vector<std::string>& devices, bool wait,
+                bool reset_scope) {
   tsl::profiler::TraceMe activity("StepMarker",
                                   tsl::profiler::TraceMeLevel::kInfo);
   torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
   XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait);
-  XLAGraphExecutor::Get()->MarkStep(device);
+  XLAGraphExecutor::Get()->MarkStep(device, reset_scope);
   bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
   if (TF_PREDICT_FALSE(debug_mode)) {
     std::string report = runtime::metrics::CreatePerformanceReport(
@@ -1698,11 +1699,12 @@ void InitXlaModuleBindings(py::module m) {
   m.def(
       "_xla_step_marker",
       [](const std::string& device, const std::vector<std::string>& devices,
-         bool wait) {
+         bool wait, bool reset_scope) {
         NoGilSection nogil;
-        StepMarker(device, devices, wait);
+        StepMarker(device, devices, wait, reset_scope);
       },
-      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
+      py::arg("reset_scope") = true);
   m.def("_get_stablehlo",
         [](const std::vector<at::Tensor>& tensors, const std::string& device,
            const std::vector<std::string>& devices,

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 5 additions & 2 deletions
@@ -404,14 +404,17 @@ void XLAGraphExecutor::SyncLiveTensorsGraph(
   SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
 }
 
-void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
+void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device,
+                                bool reset_scope) {
   // TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
   // remain as XLA_COUNTER to support
   // runtime::metrics::CreatePerformanceReport(). For more information, see
   // NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
   XLA_COUNTER("MarkStep", 1);
   DeviceContextArena::Get()->MarkStep(device);
-  torch::lazy::ScopePusher::ResetScopes();
+  if (reset_scope) {
+    torch::lazy::ScopePusher::ResetScopes();
+  }
   ResetTrimCounter();
 }

torch_xla/csrc/xla_graph_executor.h

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // Marks an execution step, which allows the tensor framework to understand
   // the computation boundaries.
   // Override to use our own DeviceContextArena.
-  void MarkStep(const torch::lazy::BackendDevice& device) final;
+  void MarkStep(const torch::lazy::BackendDevice& device, bool reset_scope);
 
   // Waits for all the outstanding operations on all the supplied devices.
   // If devices is empty, the wait will happen for all local devices.
