@@ -458,12 +458,13 @@ void SyncLiveTensors(const std::string& device_str,
458458}
459459
460460void StepMarker (const std::string& device_str,
461- const std::vector<std::string>& devices, bool wait) {
461+ const std::vector<std::string>& devices, bool wait,
462+ bool reset_scope) {
462463 tsl::profiler::TraceMe activity (" StepMarker" ,
463464 tsl::profiler::TraceMeLevel::kInfo );
464465 torch::lazy::BackendDevice device = GetDeviceOrCurrent (device_str);
465466 XLAGraphExecutor::Get ()->SyncLiveTensorsGraph (&device, devices, wait);
466- XLAGraphExecutor::Get ()->MarkStep (device);
467+ XLAGraphExecutor::Get ()->MarkStep (device, reset_scope );
467468 bool debug_mode = runtime::sys_util::GetEnvBool (" PT_XLA_DEBUG" , false );
468469 if (TF_PREDICT_FALSE (debug_mode)) {
469470 std::string report = runtime::metrics::CreatePerformanceReport (
@@ -1649,11 +1650,12 @@ void InitXlaModuleBindings(py::module m) {
16491650 m.def (
16501651 " _xla_step_marker" ,
16511652 [](const std::string& device, const std::vector<std::string>& devices,
1652- bool wait) {
1653+ bool wait, bool reset_scope ) {
16531654 NoGilSection nogil;
1654- StepMarker (device, devices, wait);
1655+ StepMarker (device, devices, wait, reset_scope );
16551656 },
1656- py::arg (" device" ) = " " , py::arg (" devices" ), py::arg (" wait" ) = true );
1657+ py::arg (" device" ) = " " , py::arg (" devices" ), py::arg (" wait" ) = true ,
1658+ py::arg (" reset_scope" ) = true );
16571659 m.def (" _get_stablehlo" ,
16581660 [](const std::vector<at::Tensor>& tensors, const std::string& device,
16591661 const std::vector<std::string>& devices,
0 commit comments