|
18 | 18 | import logging |
19 | 19 | import os |
20 | 20 | import tempfile |
| 21 | +import threading |
21 | 22 | from unittest import mock |
22 | 23 |
|
23 | 24 | import pytest |
@@ -143,6 +144,7 @@ def test_schedule(slurm_scheduler, slurm_executor): |
143 | 144 | with ( |
144 | 145 | mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), |
145 | 146 | mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"), |
| 147 | + mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"), |
146 | 148 | ): |
147 | 149 | # Create a fresh mock tunnel for each test to avoid interference |
148 | 150 | mock_tunnel = mock.MagicMock() |
@@ -473,6 +475,7 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor): |
473 | 475 | mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"), |
474 | 476 | mock.patch.object(SlurmExecutor, "parse_deps", return_value=["54321"]), |
475 | 477 | mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"), |
| 478 | + mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"), |
476 | 479 | ): |
477 | 480 | # Create a fresh mock tunnel for testing |
478 | 481 | mock_tunnel = mock.MagicMock() |
@@ -726,3 +729,216 @@ def test_non_heterogeneous_ray_cluster(slurm_scheduler, temp_dir): |
726 | 729 | # Verify run_as_group was NOT set |
727 | 730 | assert not hasattr(executor, "run_as_group") or not executor.run_as_group |
728 | 731 | assert isinstance(dryrun_info.request, SlurmRayRequest) |
| 732 | + |
| 733 | + |
| 734 | +# --------------------------------------------------------------------------- |
| 735 | +# Tests for start-time polling feature |
| 736 | +# --------------------------------------------------------------------------- |
| 737 | + |
| 738 | + |
| 739 | +def test_poll_job_start_time_prints_while_pending(slurm_scheduler, mocker): |
| 740 | + job_id = "12345" |
| 741 | + stop_event = threading.Event() |
| 742 | + mock_tunnel = mock.MagicMock() |
| 743 | + mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|PENDING\n" |
| 744 | + mock_tunnel.run.return_value.return_code = 0 |
| 745 | + |
| 746 | + mock_print = mocker.patch("builtins.print") |
| 747 | + |
| 748 | + # Stop after first iteration by setting the event inside wait |
| 749 | + def wait_once(timeout=None): |
| 750 | + stop_event.set() |
| 751 | + return True |
| 752 | + |
| 753 | + stop_event.wait = wait_once |
| 754 | + |
| 755 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 756 | + mock_print.assert_called_once() |
| 757 | + printed = mock_print.call_args[0][0] |
| 758 | + assert job_id in printed |
| 759 | + assert "PENDING" in printed |
| 760 | + assert "2026-03-14T15:30:00" in printed |
| 761 | + assert "Current time:" in printed |
| 762 | + |
| 763 | + |
| 764 | +def test_poll_job_start_time_stops_when_job_starts(slurm_scheduler, mocker): |
| 765 | + job_id = "12345" |
| 766 | + stop_event = threading.Event() |
| 767 | + mock_tunnel = mock.MagicMock() |
| 768 | + mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|RUNNING\n" |
| 769 | + mock_tunnel.run.return_value.return_code = 0 |
| 770 | + |
| 771 | + mocker.patch("builtins.print") |
| 772 | + wait_called = [] |
| 773 | + original_wait = stop_event.wait |
| 774 | + stop_event.wait = lambda t=None: wait_called.append(t) or original_wait(0) |
| 775 | + |
| 776 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 777 | + assert len(wait_called) == 0 # returned immediately, no wait |
| 778 | + |
| 779 | + |
| 780 | +def test_poll_job_start_time_stops_when_queue_empty(slurm_scheduler, mocker): |
| 781 | + job_id = "12345" |
| 782 | + stop_event = threading.Event() |
| 783 | + mock_tunnel = mock.MagicMock() |
| 784 | + mock_tunnel.run.return_value.stdout = "" |
| 785 | + mock_tunnel.run.return_value.return_code = 0 |
| 786 | + |
| 787 | + mock_print = mocker.patch("builtins.print") |
| 788 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 789 | + |
| 790 | + mock_print.assert_called_once() |
| 791 | + assert "no longer pending" in mock_print.call_args[0][0] |
| 792 | + |
| 793 | + |
| 794 | +def test_poll_job_start_time_continues_on_exception(slurm_scheduler, mocker): |
| 795 | + job_id = "12345" |
| 796 | + stop_event = threading.Event() |
| 797 | + mock_tunnel = mock.MagicMock() |
| 798 | + # First call raises, second call returns empty to stop the loop |
| 799 | + second_result = mock.MagicMock() |
| 800 | + second_result.stdout = "" |
| 801 | + mock_tunnel.run.side_effect = [ |
| 802 | + Exception("squeue failed"), |
| 803 | + second_result, |
| 804 | + ] |
| 805 | + |
| 806 | + mocker.patch("builtins.print") |
| 807 | + # Patch wait so the inter-poll sleep doesn't block the test (edge case #1) |
| 808 | + stop_event.wait = mock.MagicMock(return_value=False) |
| 809 | + |
| 810 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 811 | + assert mock_tunnel.run.call_count == 2 |
| 812 | + stop_event.wait.assert_called_once_with(30) |
| 813 | + |
| 814 | + |
| 815 | +def test_poll_job_start_time_handles_none_stdout(slurm_scheduler, mocker): |
| 816 | + job_id = "12345" |
| 817 | + stop_event = threading.Event() |
| 818 | + mock_tunnel = mock.MagicMock() |
| 819 | + mock_tunnel.run.return_value.stdout = None |
| 820 | + |
| 821 | + mock_print = mocker.patch("builtins.print") |
| 822 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 823 | + |
| 824 | + mock_print.assert_called_once() |
| 825 | + assert "no longer pending" in mock_print.call_args[0][0] |
| 826 | + |
| 827 | + |
| 828 | +def test_poll_job_start_time_skips_nonzero_return_code(slurm_scheduler, mocker): |
| 829 | + job_id = "12345" |
| 830 | + stop_event = threading.Event() |
| 831 | + mock_tunnel = mock.MagicMock() |
| 832 | + mock_tunnel.run.return_value.stdout = "slurm_load_jobs error: Invalid job id specified" |
| 833 | + mock_tunnel.run.return_value.return_code = 1 |
| 834 | + |
| 835 | + mock_print = mocker.patch("builtins.print") |
| 836 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 837 | + |
| 838 | + mock_print.assert_called_once() |
| 839 | + assert "no longer pending" in mock_print.call_args[0][0] |
| 840 | + |
| 841 | + |
| 842 | +def test_poll_job_start_time_deduplicates_array_job_lines(slurm_scheduler, mocker): |
| 843 | + job_id = "12345" |
| 844 | + stop_event = threading.Event() |
| 845 | + mock_tunnel = mock.MagicMock() |
| 846 | + mock_tunnel.run.return_value.stdout = ( |
| 847 | + f"{job_id}_1|2026-03-14T15:30:00|PENDING\n{job_id}_2|2026-03-14T15:30:00|PENDING\n" |
| 848 | + ) |
| 849 | + mock_tunnel.run.return_value.return_code = 0 |
| 850 | + |
| 851 | + mock_print = mocker.patch("builtins.print") |
| 852 | + |
| 853 | + def wait_once(timeout=None): |
| 854 | + stop_event.set() |
| 855 | + return True |
| 856 | + |
| 857 | + stop_event.wait = wait_once |
| 858 | + |
| 859 | + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) |
| 860 | + assert mock_print.call_count == 1 |
| 861 | + |
| 862 | + |
| 863 | +def test_schedule_starts_start_time_polling_thread(slurm_scheduler, mocker): |
| 864 | + job_id = "99999" |
| 865 | + dryrun_info = mock.MagicMock() |
| 866 | + |
| 867 | + mock_tunnel = mock.MagicMock() |
| 868 | + mock_tunnel.run.return_value.stdout = job_id |
| 869 | + slurm_scheduler.tunnel = mock_tunnel |
| 870 | + |
| 871 | + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") |
| 872 | + mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir") |
| 873 | + |
| 874 | + # Block the polling thread so is_alive() is True when we check |
| 875 | + started = threading.Event() |
| 876 | + |
| 877 | + def blocking_poll(poll_job_id, poll_tunnel, stop_event): |
| 878 | + started.set() |
| 879 | + stop_event.wait() |
| 880 | + |
| 881 | + mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time", side_effect=blocking_poll) |
| 882 | + |
| 883 | + slurm_scheduler.schedule(dryrun_info) |
| 884 | + |
| 885 | + started.wait(timeout=2) |
| 886 | + assert job_id in slurm_scheduler._start_time_threads |
| 887 | + thread = slurm_scheduler._start_time_threads[job_id] |
| 888 | + assert thread.daemon |
| 889 | + assert thread.is_alive() |
| 890 | + assert job_id in slurm_scheduler._start_time_stop_events |
| 891 | + |
| 892 | + # Cleanup |
| 893 | + slurm_scheduler._start_time_stop_events[job_id].set() |
| 894 | + |
| 895 | + |
| 896 | +def test_schedule_stops_existing_thread_on_duplicate_job_id(slurm_scheduler, mocker): |
| 897 | + job_id = "99999" |
| 898 | + old_ev = threading.Event() |
| 899 | + slurm_scheduler._start_time_stop_events[job_id] = old_ev |
| 900 | + slurm_scheduler._start_time_threads[job_id] = mock.MagicMock() |
| 901 | + |
| 902 | + dryrun_info = mock.MagicMock() |
| 903 | + mock_tunnel = mock.MagicMock() |
| 904 | + mock_tunnel.run.return_value.stdout = job_id |
| 905 | + slurm_scheduler.tunnel = mock_tunnel |
| 906 | + |
| 907 | + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") |
| 908 | + mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir") |
| 909 | + mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time") |
| 910 | + |
| 911 | + slurm_scheduler.schedule(dryrun_info) |
| 912 | + |
| 913 | + assert old_ev.is_set() |
| 914 | + assert slurm_scheduler._start_time_stop_events[job_id] is not old_ev # new event |
| 915 | + |
| 916 | + # Cleanup |
| 917 | + slurm_scheduler._start_time_stop_events[job_id].set() |
| 918 | + |
| 919 | + |
| 920 | +def test_close_stops_all_polling_threads(slurm_scheduler): |
| 921 | + ev1, ev2 = threading.Event(), threading.Event() |
| 922 | + slurm_scheduler._start_time_stop_events = {"1": ev1, "2": ev2} |
| 923 | + slurm_scheduler.close() |
| 924 | + assert ev1.is_set() |
| 925 | + assert ev2.is_set() |
| 926 | + assert slurm_scheduler._start_time_threads == {} |
| 927 | + assert slurm_scheduler._start_time_stop_events == {} |
| 928 | + |
| 929 | + |
| 930 | +def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker): |
| 931 | + job_id = "12345" |
| 932 | + ev = threading.Event() |
| 933 | + slurm_scheduler._start_time_stop_events[job_id] = ev |
| 934 | + slurm_scheduler._start_time_threads[job_id] = mock.MagicMock() |
| 935 | + mocker.patch( |
| 936 | + "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs", |
| 937 | + return_value={job_id: ("dir", mock.MagicMock(), "")}, |
| 938 | + ) |
| 939 | + slurm_scheduler.tunnel = mock.MagicMock() |
| 940 | + |
| 941 | + slurm_scheduler._cancel_existing(job_id) |
| 942 | + |
| 943 | + assert ev.is_set() |
| 944 | + assert job_id not in slurm_scheduler._start_time_stop_events |
0 commit comments