Skip to content

Commit a74b5f0

Browse files
authored
ECS Executor: Set tasks to RUNNING state once active (#39212)
Tasks were previously being put into QUEUED state after they were active in the ECS executor. This was to store executor state for task adoption but had the side effect of removing them from the list of running task instances (which has other knock-on effects). Instead, change tasks into the RUNNING state, and do not remove them from the list of running tasks. * Update change_state usage in debug and celery executor - DebugExecutor: was overriding the change_state method from the base executor, but changing no behaviour, so move to using the base executor implementation - CeleryExecutor: Plumb through the new param so that the signature matches the base executor * Call running_state in try/catch for backcompat
1 parent 8965f2e commit a74b5f0

7 files changed

Lines changed: 97 additions & 18 deletions

File tree

airflow/executors/base_executor.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,23 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
303303
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
304304
self.running.add(key)
305305

306-
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
306+
def change_state(
307+
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
308+
) -> None:
307309
"""
308310
Change state of the task.
309311
310-
:param info: Executor information for the task instance
311312
:param key: Unique key for the task instance
312313
:param state: State to set for the task.
314+
:param info: Executor information for the task instance
315+
:param remove_running: Whether or not to remove the TI key from running set
313316
"""
314317
self.log.debug("Changing state: %s", key)
315-
try:
316-
self.running.remove(key)
317-
except KeyError:
318-
self.log.debug("Could not find key: %s", key)
318+
if remove_running:
319+
try:
320+
self.running.remove(key)
321+
except KeyError:
322+
self.log.debug("Could not find key: %s", key)
319323
self.event_buffer[key] = state, info
320324

321325
def fail(self, key: TaskInstanceKey, info=None) -> None:
@@ -345,6 +349,15 @@ def queued(self, key: TaskInstanceKey, info=None) -> None:
345349
"""
346350
self.change_state(key, TaskInstanceState.QUEUED, info)
347351

352+
def running_state(self, key: TaskInstanceKey, info=None) -> None:
353+
"""
354+
Set running state for the event.
355+
356+
:param info: Executor information for the task instance
357+
:param key: Unique key for the task instance
358+
"""
359+
self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False)
360+
348361
def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
349362
"""
350363
Return and flush the event buffer.

airflow/executors/debug_executor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,3 @@ def end(self) -> None:
155155

156156
def terminate(self) -> None:
157157
self._terminated.set()
158-
159-
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
160-
self.log.debug("Popping %s from executor task queue.", key)
161-
self.running.remove(key)
162-
self.event_buffer[key] = state, info

airflow/jobs/scheduler_job_runner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,12 @@ def _process_executor_events(self, session: Session) -> int:
692692
ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number
693693

694694
self.log.info("Received executor event with state %s for task instance %s", state, ti_key)
695-
if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED):
695+
if state in (
696+
TaskInstanceState.FAILED,
697+
TaskInstanceState.SUCCESS,
698+
TaskInstanceState.QUEUED,
699+
TaskInstanceState.RUNNING,
700+
):
696701
tis_with_right_state.append(ti_key)
697702

698703
# Return if no finished tasks
@@ -711,7 +716,7 @@ def _process_executor_events(self, session: Session) -> int:
711716
buffer_key = ti.key.with_try_number(try_number)
712717
state, info = event_buffer.pop(buffer_key)
713718

714-
if state == TaskInstanceState.QUEUED:
719+
if state in (TaskInstanceState.QUEUED, TaskInstanceState.RUNNING):
715720
ti.external_executor_id = info
716721
self.log.info("Setting external_id for %s to %s", ti, info)
717722
continue

airflow/providers/amazon/aws/executors/ecs/ecs_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,12 @@ def attempt_task_runs(self):
400400
else:
401401
task = run_task_response["tasks"][0]
402402
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
403-
self.queued(task_key, task.task_arn)
403+
try:
404+
self.running_state(task_key, task.task_arn)
405+
except AttributeError:
406+
# running_state is newly added, and only needed to support task adoption (an optional
407+
# executor feature).
408+
pass
404409
if failure_reasons:
405410
self.log.error(
406411
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",

airflow/providers/celery/executors/celery_executor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,14 @@ def update_all_task_states(self) -> None:
368368
if state:
369369
self.update_task_state(key, state, info)
370370

371-
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
372-
super().change_state(key, state, info)
371+
def change_state(
372+
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
373+
) -> None:
374+
try:
375+
super().change_state(key, state, info, remove_running=remove_running)
376+
except AttributeError:
377+
# Earlier versions of the BaseExecutor don't accept the remove_running parameter for this method
378+
super().change_state(key, state, info)
373379
self.tasks.pop(key, None)
374380

375381
def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:

tests/executors/test_base_executor.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from airflow.models.baseoperator import BaseOperator
3434
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
3535
from airflow.utils import timezone
36-
from airflow.utils.state import State
36+
from airflow.utils.state import State, TaskInstanceState
3737

3838

3939
def test_supports_sentry():
@@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, total_tries):
363363
assert a.elapsed > min_seconds_for_test
364364
assert a.total_tries == total_tries
365365
assert a.tries_after_min == 1
366+
367+
368+
def test_state_fail():
369+
executor = BaseExecutor()
370+
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
371+
executor.running.add(key)
372+
info = "info"
373+
executor.fail(key, info=info)
374+
assert not executor.running
375+
assert executor.event_buffer[key] == (TaskInstanceState.FAILED, info)
376+
377+
378+
def test_state_success():
379+
executor = BaseExecutor()
380+
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
381+
executor.running.add(key)
382+
info = "info"
383+
executor.success(key, info=info)
384+
assert not executor.running
385+
assert executor.event_buffer[key] == (TaskInstanceState.SUCCESS, info)
386+
387+
388+
def test_state_queued():
389+
executor = BaseExecutor()
390+
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
391+
executor.running.add(key)
392+
info = "info"
393+
executor.queued(key, info=info)
394+
assert not executor.running
395+
assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)
396+
397+
398+
def test_state_generic():
399+
executor = BaseExecutor()
400+
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
401+
executor.running.add(key)
402+
info = "info"
403+
executor.queued(key, info=info)
404+
assert not executor.running
405+
assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)
406+
407+
408+
def test_state_running():
409+
executor = BaseExecutor()
410+
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
411+
executor.running.add(key)
412+
info = "info"
413+
executor.running_state(key, info=info)
414+
# Running state should not remove a command as running
415+
assert executor.running
416+
assert executor.event_buffer[key] == (TaskInstanceState.RUNNING, info)

tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ def test_stopped_tasks(self):
367367
class TestAwsEcsExecutor:
368368
"""Tests the AWS ECS Executor."""
369369

370-
def test_execute(self, mock_airflow_key, mock_executor):
370+
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
371+
def test_execute(self, change_state_mock, mock_airflow_key, mock_executor):
371372
"""Test execution from end-to-end."""
372373
airflow_key = mock_airflow_key()
373374

@@ -393,6 +394,9 @@ def test_execute(self, mock_airflow_key, mock_executor):
393394
# Task is stored in active worker.
394395
assert 1 == len(mock_executor.active_workers)
395396
assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn
397+
change_state_mock.assert_called_once_with(
398+
airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
399+
)
396400

397401
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
398402
def test_success_execute_api_exception(self, mock_backoff, mock_executor):

0 commit comments

Comments
 (0)