diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 586a2053b83eb..8640faae8cf7b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -32,7 +32,7 @@ from itertools import groupby from typing import TYPE_CHECKING, Any, Callable -from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update +from sqlalchemy import and_, delete, desc, exists, func, select, text, tuple_, update from sqlalchemy.exc import OperationalError from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression @@ -2043,19 +2043,34 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE We can then use this information to determine whether to reschedule a task or fail it. """ - return ( - session.query(Log) + last_running_time = session.scalar( + select(Log.dttm) .where( - Log.task_id == ti.task_id, Log.dag_id == ti.dag_id, + Log.task_id == ti.task_id, Log.run_id == ti.run_id, Log.map_index == ti.map_index, Log.try_number == ti.try_number, - Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT, + Log.event == "running", ) - .count() + .order_by(desc(Log.dttm)) + .limit(1) + ) + + query = session.query(Log).where( + Log.task_id == ti.task_id, + Log.dag_id == ti.dag_id, + Log.run_id == ti.run_id, + Log.map_index == ti.map_index, + Log.try_number == ti.try_number, + Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT, ) + if last_running_time: + query = query.where(Log.dttm > last_running_time) + + return query.count() + previous_ti_running_metrics: dict[tuple[str, str, str], int] = {} @provide_session diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index ac492590e5b65..f6601783d6081 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ 
@conf_vars({("scheduler", "num_stuck_in_queued_retries"): "2"})
def test_handle_stuck_queued_tasks_reschedule_sensors(self, dag_maker, session, mock_executors):
    """
    Check that a "running" log event resets the stuck-in-queued counter.

    Reschedule sensors go in and out of running repeatedly while keeping the
    same try_number, so they must get three attempts per reschedule rather
    than three attempts per try_number.
    """
    reschedule_event = "stuck in queued reschedule"
    exceeded_event = "stuck in queued tries exceeded"

    with dag_maker("test_fail_stuck_queued_tasks_multiple_executors"):
        EmptyOperator(task_id="op1")
        EmptyOperator(task_id="op2", executor="default_exec")

    def _queue_tasks(tis):
        # Put every task instance back into "queued" with a fresh timestamp.
        for task_instance in tis:
            task_instance.state = "queued"
            task_instance.queued_dttm = timezone.utcnow()
        session.commit()

    def _add_running_event(tis):
        # Record a "running" Log row for each task instance (same try_number).
        for task_instance in tis:
            session.add(
                Log(
                    dttm=timezone.utcnow(),
                    dag_id=task_instance.dag_id,
                    task_id=task_instance.task_id,
                    map_index=task_instance.map_index,
                    event="running",
                    run_id=task_instance.run_id,
                    try_number=task_instance.try_number,
                )
            )

    def _logged_events():
        # Events recorded for this dag run, in insertion (Log.id) order.
        rows = session.scalars(select(Log).where(Log.run_id == run_id).order_by(Log.id)).all()
        return [row.event for row in rows]

    def _handle_stuck():
        with _loader_mock(mock_executors):
            scheduler._handle_tasks_stuck_in_queued()

    run_id = str(uuid4())
    dr = dag_maker.create_dagrun(run_id=run_id)

    tis = dr.get_task_instances(session=session)
    _queue_tasks(tis=tis)
    scheduler = SchedulerJobRunner(job=Job(), num_runs=0)
    scheduler._task_queued_timeout = -300  # always in violation of timeout

    # First time stuck in queued: the tasks are reset to scheduled.
    _handle_stuck()
    tis = dr.get_task_instances(session=session)
    assert [ti.state for ti in tis] == ["scheduled"] * 2
    assert [ti.queued_dttm for ti in tis] == [None] * 2

    _queue_tasks(tis=tis)
    assert _logged_events() == [reschedule_event] * 2

    # Second time stuck in queued: still rescheduled, nothing failed yet.
    _handle_stuck()
    assert _logged_events() == [reschedule_event] * 4
    mock_executors[0].fail.assert_not_called()
    tis = dr.get_task_instances(session=session)
    assert [ti.state for ti in tis] == ["scheduled"] * 2

    _add_running_event(tis)  # This should "reset" the count of stuck queued

    # Should be able to be stuck 3 more times before failing
    for _ in range(3):
        _queue_tasks(tis=tis)
        _handle_stuck()
        tis = dr.get_task_instances(session=session)

    expected_events = (
        [reschedule_event] * 4
        + ["running"] * 2
        + [reschedule_event] * 4
        + [exceeded_event] * 2
    )
    assert _logged_events() == expected_events

    mock_executors[0].fail.assert_not_called()  # just demoing that we don't fail with executor method
    final_states = [ti.state for ti in dr.get_task_instances(session=session)]
    assert final_states == ["failed", "failed"]