|
23 | 23 | import selectors |
24 | 24 | import time |
25 | 25 | from collections.abc import AsyncIterator |
| 26 | +from concurrent.futures import Future |
26 | 27 | from typing import TYPE_CHECKING, Any |
27 | 28 | from unittest.mock import ANY, MagicMock, patch |
28 | 29 |
|
@@ -326,6 +327,32 @@ async def test_invalid_trigger(self, supervisor_builder): |
326 | 327 | assert trigger_id == 1 |
327 | 328 | assert traceback[-1] == "ModuleNotFoundError: No module named 'fake'\n" |
328 | 329 |
|
| 330 | + @pytest.mark.asyncio |
| 331 | + @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) |
| 332 | + async def test_sync_state_to_supervisor_calls_stdin_threadpool_executor(self, mock_supervisor_comms): |
| 333 | + workload = workloads.RunTrigger.model_construct( |
| 334 | + id=1, ti=None, classpath="fake.classpath", encrypted_kwargs={} |
| 335 | + ) |
| 336 | + |
| 337 | + trigger_runner = TriggerRunner() |
| 338 | + trigger_runner.requests_sock = MagicMock() |
| 339 | + trigger_runner.to_create.append(workload) |
| 340 | + |
| 341 | + await trigger_runner.create_triggers() |
| 342 | + ids = await trigger_runner.cleanup_finished_triggers() |
| 343 | + |
| 344 | + future = Future() |
| 345 | + read_line_output = b'{"type": "TriggerStateSync", "to_create": [], "to_cancel": []}\n' |
| 346 | + future.set_result(read_line_output) |
| 347 | + |
| 348 | + with patch.object(trigger_runner, "_stdin_threadpool_executor", MagicMock()) as mock_executor: |
| 349 | + mock_executor.submit.return_value = future |
| 350 | + |
| 351 | + await trigger_runner.sync_state_to_supervisor(ids) |
| 352 | + |
| 353 | + # Assert that _stdin_threadpool_executor.submit was called with the correct task |
| 354 | + mock_executor.submit.assert_called_once_with(mock_supervisor_comms._read_stdin_line) |
| 355 | + |
329 | 356 |
|
330 | 357 | @pytest.mark.asyncio |
331 | 358 | async def test_trigger_create_race_condition_38599(session, supervisor_builder): |
|
0 commit comments