
Commit f68f6f2

ko3n1g and claude authored
feat: poll and print SLURM job estimated start time while pending (#464)
* feat: poll and print SLURM job estimated start time while pending

  When a SLURM job is submitted and sits in the queue, there is no feedback about when it is expected to start. This adds a background daemon thread per job that polls `squeue --start` every 30 seconds and prints the estimated start time to stdout, stopping automatically once the job leaves the pending queue.

  Key details:
  - `_poll_job_start_time`: new method guards against None stdout, non-zero return codes, and array-job multi-line output (prints only the first line)
  - Thread is started in `schedule()` and stopped in `_cancel_existing()` and `close()`; duplicate job_id (retry) stops the old thread first
  - 11 new TDD tests cover all edge cases from the plan

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: oliver könig <okoenig@nvidia.com>

* feat: print current timestamp in SLURM job start time poll output

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: oliver könig <okoenig@nvidia.com>

* feat: exponential backoff for SLURM job start time polling

  Replace the fixed 30s interval with exponential backoff (30s base, 2x factor, capped at 15 min) to reduce unnecessary polling for long-pending jobs.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  Signed-off-by: oliver könig <okoenig@nvidia.com>

---------

Signed-off-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
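
For a quick sense of the final polling cadence, the backoff formula in the diff below, `delay = min(30 * (2**attempt), 900)`, yields delays of 30s, 60s, 120s, 240s, 480s, and then 900s (the 15-minute cap) for every attempt after that. A one-liner to verify:

    >>> [min(30 * 2**a, 900) for a in range(7)]
    [30, 60, 120, 240, 480, 900, 900]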
1 parent 0c556ca commit f68f6f2

2 files changed

Lines changed: 280 additions & 1 deletion


nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 64 additions & 1 deletion
@@ -22,6 +22,7 @@
 import json
 import logging
 import os
+import threading
 import time
 from dataclasses import asdict
 from datetime import datetime

@@ -78,6 +79,8 @@ def __init__(
         super().__init__(session_name)
         self.experiment = experiment
         self._consecutive_sacct_failures: dict[str, int] = {}
+        self._start_time_threads: dict[str, threading.Thread] = {}
+        self._start_time_stop_events: dict[str, threading.Event] = {}

     # TODO: Move this into the SlurmExecutor
     def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):

@@ -190,6 +193,41 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t

         return AppDryRunInfo(req, repr)

+    def _poll_job_start_time(
+        self, job_id: str, tunnel: Tunnel, stop_event: threading.Event
+    ) -> None:
+        attempt = 0
+        while not stop_event.is_set():
+            try:
+                result = tunnel.run(
+                    f"squeue --start --noheader -j {job_id} -o '%i|%S|%T'",
+                    warn=True,
+                    hide=True,
+                )
+                output = (result.stdout or "").strip()
+                if output and result.return_code == 0:
+                    # Array jobs produce one line per task — print only the first
+                    line = output.splitlines()[0]
+                    parts = line.strip().split("|")
+                    if len(parts) >= 3:
+                        _, start_time, state = parts[0].strip(), parts[1].strip(), parts[2].strip()
+                        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                        print(
+                            f"[SLURM] Job {job_id} - State: {state}, Estimated start: {start_time}, Current time: {now}",
+                            flush=True,
+                        )
+                        if state.upper() not in ("PENDING", "CF", "CONFIGURING"):
+                            return
+                else:
+                    print(f"[SLURM] Job {job_id} is no longer pending.", flush=True)
+                    return
+            except Exception as e:
+                log.debug(f"Failed to poll start time for job {job_id}: {e}")
+
+            delay = min(30 * (2**attempt), 900)
+            attempt += 1
+            stop_event.wait(delay)
+
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayRequest]) -> str: # type: ignore
         # Setup
         req = dryrun_info.request
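
To make the parsing above concrete, here is a minimal standalone sketch; the sample line is hypothetical, standing in for real `squeue --start --noheader -o '%i|%S|%T'` output for an array job:

    # Only the first line is used; array jobs emit one line per task.
    output = "12345_1|2026-03-14T15:30:00|PENDING\n12345_2|2026-03-14T15:30:00|PENDING\n"
    line = output.strip().splitlines()[0]
    job, start_time, state = (p.strip() for p in line.split("|")[:3])
    print(job, start_time, state)  # -> 12345_1 2026-03-14T15:30:00 PENDING

Note the state check in the method: it prints on every iteration and keeps polling only while the state is PENDING, CF, or CONFIGURING; any other state is printed once and the thread exits.
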
@@ -218,6 +256,23 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayReques

         # Save metadata
         _save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)
+
+        # Stop any existing polling thread for this job_id (retry scenario)
+        if job_id in self._start_time_stop_events:
+            self._start_time_stop_events.pop(job_id).set()
+            self._start_time_threads.pop(job_id, None)
+
+        stop_event = threading.Event()
+        self._start_time_stop_events[job_id] = stop_event
+        thread = threading.Thread(
+            target=self._poll_job_start_time,
+            args=(job_id, self.tunnel, stop_event),
+            daemon=True,
+            name=f"slurm-start-time-{job_id}",
+        )
+        self._start_time_threads[job_id] = thread
+        thread.start()
+
         return job_id

     def _cancel_existing(self, app_id: str) -> None:
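
The thread wiring in `schedule()` is the standard cooperative-shutdown pattern: one `threading.Event` per worker, with `Event.wait(delay)` doubling as an interruptible sleep. A self-contained sketch of the same idea, using generic names rather than the NeMo-Run API:

    import threading
    import time

    def worker(stop: threading.Event) -> None:
        while not stop.is_set():
            ...  # one unit of work
            stop.wait(1.0)  # sleeps up to 1s, but returns immediately once stop is set

    stop = threading.Event()
    t = threading.Thread(target=worker, args=(stop,), daemon=True, name="poller")
    t.start()
    time.sleep(0.1)
    stop.set()        # wakes any pending wait() right away
    t.join(timeout=2)

Because the thread is daemonic, it also cannot keep the interpreter alive if shutdown happens without an explicit stop.
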
@@ -231,6 +286,10 @@ def _cancel_existing(self, app_id: str) -> None:
         assert self.tunnel, "Tunnel is None."
         self.tunnel.run(f"scancel {app_id}", hide=False)

+        if app_id in self._start_time_stop_events:
+            self._start_time_stop_events.pop(app_id).set()
+            self._start_time_threads.pop(app_id, None)
+
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
             job_dirs = _get_job_dirs()

@@ -366,7 +425,11 @@ def log_iter(
         else:
             return [f"Failed getting logs for {app_id}"]

-    def close(self) -> None: ...
+    def close(self) -> None:
+        for stop_event in self._start_time_stop_events.values():
+            stop_event.set()
+        self._start_time_threads.clear()
+        self._start_time_stop_events.clear()


 class TunnelLogIterator(LogIterator):
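
One detail worth noting about `close()`: it only signals the events and never joins the threads. That is sufficient because each poller blocks in `stop_event.wait(delay)`, which returns as soon as the event is set, as this small demonstration shows:

    import threading
    import time

    ev = threading.Event()
    t0 = time.monotonic()
    threading.Timer(0.1, ev.set).start()  # set the event from another thread
    ev.wait(900)  # nominally the 15-minute cap
    print(f"woke after {time.monotonic() - t0:.2f}s")  # ~0.10s, not 900s
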

test/run/torchx_backend/schedulers/test_slurm.py

Lines changed: 216 additions & 0 deletions
@@ -18,6 +18,7 @@
 import logging
 import os
 import tempfile
+import threading
 from unittest import mock

 import pytest

@@ -143,6 +144,7 @@ def test_schedule(slurm_scheduler, slurm_executor):
     with (
         mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
        mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
+        mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
     ):
         # Create a fresh mock tunnel for each test to avoid interference
         mock_tunnel = mock.MagicMock()

@@ -473,6 +475,7 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
         mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
         mock.patch.object(SlurmExecutor, "parse_deps", return_value=["54321"]),
         mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
+        mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
     ):
         # Create a fresh mock tunnel for testing
         mock_tunnel = mock.MagicMock()

@@ -726,3 +729,216 @@ def test_non_heterogeneous_ray_cluster(slurm_scheduler, temp_dir):
     # Verify run_as_group was NOT set
     assert not hasattr(executor, "run_as_group") or not executor.run_as_group
     assert isinstance(dryrun_info.request, SlurmRayRequest)
+
+
+# ---------------------------------------------------------------------------
+# Tests for start-time polling feature
+# ---------------------------------------------------------------------------
+
+
+def test_poll_job_start_time_prints_while_pending(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|PENDING\n"
+    mock_tunnel.run.return_value.return_code = 0
+
+    mock_print = mocker.patch("builtins.print")
+
+    # Stop after first iteration by setting the event inside wait
+    def wait_once(timeout=None):
+        stop_event.set()
+        return True
+
+    stop_event.wait = wait_once
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    mock_print.assert_called_once()
+    printed = mock_print.call_args[0][0]
+    assert job_id in printed
+    assert "PENDING" in printed
+    assert "2026-03-14T15:30:00" in printed
+    assert "Current time:" in printed
+
+
+def test_poll_job_start_time_stops_when_job_starts(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|RUNNING\n"
+    mock_tunnel.run.return_value.return_code = 0
+
+    mocker.patch("builtins.print")
+    wait_called = []
+    original_wait = stop_event.wait
+    stop_event.wait = lambda t=None: wait_called.append(t) or original_wait(0)
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    assert len(wait_called) == 0  # returned immediately, no wait
+
+
+def test_poll_job_start_time_stops_when_queue_empty(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = ""
+    mock_tunnel.run.return_value.return_code = 0
+
+    mock_print = mocker.patch("builtins.print")
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+
+    mock_print.assert_called_once()
+    assert "no longer pending" in mock_print.call_args[0][0]
+
+
+def test_poll_job_start_time_continues_on_exception(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    # First call raises, second call returns empty to stop the loop
+    second_result = mock.MagicMock()
+    second_result.stdout = ""
+    mock_tunnel.run.side_effect = [
+        Exception("squeue failed"),
+        second_result,
+    ]
+
+    mocker.patch("builtins.print")
+    # Patch wait so the inter-poll sleep doesn't block the test (edge case #1)
+    stop_event.wait = mock.MagicMock(return_value=False)
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    assert mock_tunnel.run.call_count == 2
+    stop_event.wait.assert_called_once_with(30)
+
+
+def test_poll_job_start_time_handles_none_stdout(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = None
+
+    mock_print = mocker.patch("builtins.print")
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+
+    mock_print.assert_called_once()
+    assert "no longer pending" in mock_print.call_args[0][0]
+
+
+def test_poll_job_start_time_skips_nonzero_return_code(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = "slurm_load_jobs error: Invalid job id specified"
+    mock_tunnel.run.return_value.return_code = 1
+
+    mock_print = mocker.patch("builtins.print")
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+
+    mock_print.assert_called_once()
+    assert "no longer pending" in mock_print.call_args[0][0]
+
+
+def test_poll_job_start_time_deduplicates_array_job_lines(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = (
+        f"{job_id}_1|2026-03-14T15:30:00|PENDING\n{job_id}_2|2026-03-14T15:30:00|PENDING\n"
+    )
+    mock_tunnel.run.return_value.return_code = 0
+
+    mock_print = mocker.patch("builtins.print")
+
+    def wait_once(timeout=None):
+        stop_event.set()
+        return True
+
+    stop_event.wait = wait_once
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    assert mock_print.call_count == 1
+
+
+def test_schedule_starts_start_time_polling_thread(slurm_scheduler, mocker):
+    job_id = "99999"
+    dryrun_info = mock.MagicMock()
+
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = job_id
+    slurm_scheduler.tunnel = mock_tunnel
+
+    mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
+    mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
+
+    # Block the polling thread so is_alive() is True when we check
+    started = threading.Event()
+
+    def blocking_poll(poll_job_id, poll_tunnel, stop_event):
+        started.set()
+        stop_event.wait()
+
+    mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time", side_effect=blocking_poll)
+
+    slurm_scheduler.schedule(dryrun_info)
+
+    started.wait(timeout=2)
+    assert job_id in slurm_scheduler._start_time_threads
+    thread = slurm_scheduler._start_time_threads[job_id]
+    assert thread.daemon
+    assert thread.is_alive()
+    assert job_id in slurm_scheduler._start_time_stop_events
+
+    # Cleanup
+    slurm_scheduler._start_time_stop_events[job_id].set()
+
+
+def test_schedule_stops_existing_thread_on_duplicate_job_id(slurm_scheduler, mocker):
+    job_id = "99999"
+    old_ev = threading.Event()
+    slurm_scheduler._start_time_stop_events[job_id] = old_ev
+    slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()
+
+    dryrun_info = mock.MagicMock()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = job_id
+    slurm_scheduler.tunnel = mock_tunnel
+
+    mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
+    mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
+    mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time")
+
+    slurm_scheduler.schedule(dryrun_info)
+
+    assert old_ev.is_set()
+    assert slurm_scheduler._start_time_stop_events[job_id] is not old_ev  # new event
+
+    # Cleanup
+    slurm_scheduler._start_time_stop_events[job_id].set()
+
+
+def test_close_stops_all_polling_threads(slurm_scheduler):
+    ev1, ev2 = threading.Event(), threading.Event()
+    slurm_scheduler._start_time_stop_events = {"1": ev1, "2": ev2}
+    slurm_scheduler.close()
+    assert ev1.is_set()
+    assert ev2.is_set()
+    assert slurm_scheduler._start_time_threads == {}
+    assert slurm_scheduler._start_time_stop_events == {}
+
+
+def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker):
+    job_id = "12345"
+    ev = threading.Event()
+    slurm_scheduler._start_time_stop_events[job_id] = ev
+    slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()
+    mocker.patch(
+        "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs",
+        return_value={job_id: ("dir", mock.MagicMock(), "")},
+    )
+    slurm_scheduler.tunnel = mock.MagicMock()
+
+    slurm_scheduler._cancel_existing(job_id)
+
+    assert ev.is_set()
+    assert job_id not in slurm_scheduler._start_time_stop_events
