diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256 index ec5a663b3429b..fb89acbe007e2 100644 --- a/airflow-core/docs/img/airflow_erd.sha256 +++ b/airflow-core/docs/img/airflow_erd.sha256 @@ -1 +1 @@ -20aac73bfb0e08226c687304b997ed93d9cf796d3b761d14e45d732ef4f6c01e \ No newline at end of file +db00d57fce32830b69f2c1481b231e65e67e197b4a96a5fa1c870cd555eac3bd \ No newline at end of file diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg index dbcc7d4060ec2..7398a42070836 100644 --- a/airflow-core/docs/img/airflow_erd.svg +++ b/airflow-core/docs/img/airflow_erd.svg @@ -714,169 +714,164 @@ task_instance - -task_instance - -id - - [UUID] - NOT NULL - -context_carrier - - [JSONB] - -custom_operator_name - - [VARCHAR(1000)] - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - -duration - - [DOUBLE_PRECISION] - -end_date - - [TIMESTAMP] - -executor - - [VARCHAR(1000)] - -executor_config - - [BYTEA] - -external_executor_id - - [VARCHAR(250)] - -hostname - - [VARCHAR(1000)] - -last_heartbeat_at - - [TIMESTAMP] - -map_index - - [INTEGER] - NOT NULL - -max_tries - - [INTEGER] - -next_kwargs - - [JSONB] - -next_method - - [VARCHAR(1000)] - -operator - - [VARCHAR(1000)] - -pid - - [INTEGER] - -pool - - [VARCHAR(256)] - NOT NULL - -pool_slots - - [INTEGER] - NOT NULL - -priority_weight - - [INTEGER] - -queue - - [VARCHAR(256)] - -queued_by_job_id - - [INTEGER] - -queued_dttm - - [TIMESTAMP] - -rendered_map_index - - [VARCHAR(250)] - -run_id - - [VARCHAR(250)] - NOT NULL - -scheduled_dttm - - [TIMESTAMP] - -span_status - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(20)] - -task_display_name - - [VARCHAR(2000)] - -task_id - - [VARCHAR(250)] - NOT NULL - -trigger_id - - [INTEGER] - -trigger_timeout - - [TIMESTAMP] - -try_id - - [UUID] - NOT NULL - -try_number - - [INTEGER] - -unixname - - [VARCHAR(1000)] - -updated_at - - [TIMESTAMP] + 
+task_instance + +id + + [UUID] + NOT NULL + +context_carrier + + [JSONB] + +custom_operator_name + + [VARCHAR(1000)] + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + +duration + + [DOUBLE_PRECISION] + +end_date + + [TIMESTAMP] + +executor + + [VARCHAR(1000)] + +executor_config + + [BYTEA] + +external_executor_id + + [VARCHAR(250)] + +hostname + + [VARCHAR(1000)] + +last_heartbeat_at + + [TIMESTAMP] + +map_index + + [INTEGER] + NOT NULL + +max_tries + + [INTEGER] + +next_kwargs + + [JSONB] + +next_method + + [VARCHAR(1000)] + +operator + + [VARCHAR(1000)] + +pid + + [INTEGER] + +pool + + [VARCHAR(256)] + NOT NULL + +pool_slots + + [INTEGER] + NOT NULL + +priority_weight + + [INTEGER] + +queue + + [VARCHAR(256)] + +queued_by_job_id + + [INTEGER] + +queued_dttm + + [TIMESTAMP] + +rendered_map_index + + [VARCHAR(250)] + +run_id + + [VARCHAR(250)] + NOT NULL + +scheduled_dttm + + [TIMESTAMP] + +span_status + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(20)] + +task_display_name + + [VARCHAR(2000)] + +task_id + + [VARCHAR(250)] + NOT NULL + +trigger_id + + [INTEGER] + +trigger_timeout + + [TIMESTAMP] + +try_number + + [INTEGER] + +unixname + + [VARCHAR(1000)] + +updated_at + + [TIMESTAMP] @@ -888,477 +883,467 @@ rendered_task_instance_fields - -rendered_task_instance_fields - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -k8s_pod_yaml - - [JSON] - -rendered_fields - - [JSON] - NOT NULL + +rendered_task_instance_fields + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +k8s_pod_yaml + + [JSON] + +rendered_fields + + [JSON] + NOT NULL task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 
task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_map - -task_map - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -keys - - [JSONB] - -length - - [INTEGER] - NOT NULL + +task_map + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +keys + + [JSONB] + +length + + [INTEGER] + NOT NULL task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_reschedule - -task_reschedule - -id - - [INTEGER] - NOT NULL - -duration - - [INTEGER] - NOT NULL - -end_date - - [TIMESTAMP] - NOT NULL - -reschedule_date - - [TIMESTAMP] - NOT NULL - -start_date - - [TIMESTAMP] - NOT NULL - -ti_id - - [UUID] - NOT NULL - -try_number - - [INTEGER] - NOT NULL + +task_reschedule + +id + + [INTEGER] + NOT NULL + +duration + + [INTEGER] + NOT NULL + +end_date + + [TIMESTAMP] + NOT NULL + +reschedule_date + + [TIMESTAMP] + NOT NULL + +start_date + + [TIMESTAMP] + NOT NULL + +ti_id + + [UUID] + NOT NULL task_instance--task_reschedule - -0..N -1 + +0..N +1 xcom - -xcom - -dag_run_id - - [INTEGER] - NOT NULL - -key - - [VARCHAR(512)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -timestamp - - [TIMESTAMP] - NOT NULL - -value - - [JSONB] + +xcom + +dag_run_id + + [INTEGER] + NOT NULL + +key + + [VARCHAR(512)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +timestamp + + [TIMESTAMP] + NOT NULL + +value + + [JSONB] 
task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance_note - -task_instance_note - -ti_id - - [UUID] - NOT NULL - -content - - [VARCHAR(1000)] - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [VARCHAR(128)] + +task_instance_note + +ti_id + + [UUID] + NOT NULL + +content + + [VARCHAR(1000)] + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +user_id + + [VARCHAR(128)] task_instance--task_instance_note - -1 -1 + +1 +1 task_instance_history - -task_instance_history - -try_id - - [UUID] - NOT NULL - -context_carrier - - [JSONB] - -custom_operator_name - - [VARCHAR(1000)] - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - -duration - - [DOUBLE_PRECISION] - -end_date - - [TIMESTAMP] - -executor - - [VARCHAR(1000)] - -executor_config - - [BYTEA] - -external_executor_id - - [VARCHAR(250)] - -hostname - - [VARCHAR(1000)] - -map_index - - [INTEGER] - NOT NULL - -max_tries - - [INTEGER] - -next_kwargs - - [JSONB] - -next_method - - [VARCHAR(1000)] - -operator - - [VARCHAR(1000)] - -pid - - [INTEGER] - -pool - - [VARCHAR(256)] - NOT NULL - -pool_slots - - [INTEGER] - NOT NULL - -priority_weight - - [INTEGER] - -queue - - [VARCHAR(256)] - -queued_by_job_id - - [INTEGER] - -queued_dttm - - [TIMESTAMP] - -rendered_map_index - - [VARCHAR(250)] - -run_id - - [VARCHAR(250)] - NOT NULL - -scheduled_dttm - - [TIMESTAMP] - -span_status - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(20)] - -task_display_name - - [VARCHAR(2000)] - -task_id - - [VARCHAR(250)] - NOT NULL - -task_instance_id - - [UUID] - NOT NULL - -trigger_id - - [INTEGER] - -trigger_timeout - - [TIMESTAMP] - -try_number - - [INTEGER] - NOT NULL - -unixname - - [VARCHAR(1000)] - -updated_at - - [TIMESTAMP] + +task_instance_history + +task_instance_id + + [UUID] + 
NOT NULL + +context_carrier + + [JSONB] + +custom_operator_name + + [VARCHAR(1000)] + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + +duration + + [DOUBLE_PRECISION] + +end_date + + [TIMESTAMP] + +executor + + [VARCHAR(1000)] + +executor_config + + [BYTEA] + +external_executor_id + + [VARCHAR(250)] + +hostname + + [VARCHAR(1000)] + +map_index + + [INTEGER] + NOT NULL + +max_tries + + [INTEGER] + +next_kwargs + + [JSONB] + +next_method + + [VARCHAR(1000)] + +operator + + [VARCHAR(1000)] + +pid + + [INTEGER] + +pool + + [VARCHAR(256)] + NOT NULL + +pool_slots + + [INTEGER] + NOT NULL + +priority_weight + + [INTEGER] + +queue + + [VARCHAR(256)] + +queued_by_job_id + + [INTEGER] + +queued_dttm + + [TIMESTAMP] + +rendered_map_index + + [VARCHAR(250)] + +run_id + + [VARCHAR(250)] + NOT NULL + +scheduled_dttm + + [TIMESTAMP] + +span_status + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(20)] + +task_display_name + + [VARCHAR(2000)] + +task_id + + [VARCHAR(250)] + NOT NULL + +trigger_id + + [INTEGER] + +trigger_timeout + + [TIMESTAMP] + +try_number + + [INTEGER] + NOT NULL + +unixname + + [VARCHAR(1000)] + +updated_at + + [TIMESTAMP] task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 @@ -2073,8 +2058,8 @@ dag_version--task_instance - -0..N + +0..N {0,1} diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index dbed28f94864e..2f35cf2deecf4 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | 
Description | +=========================+==================+===================+==============================================================+ -| ``959e216a3abb`` (head) | ``0e9519b56710`` | ``3.0.0`` | Rename ``is_active`` to ``is_stale`` column in ``dag`` | +| ``29ce7909c52b`` (head) | ``959e216a3abb`` | ``3.0.0`` | Change TI table to have unique UUID id/pk per attempt. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``959e216a3abb`` | ``0e9519b56710`` | ``3.0.0`` | Rename ``is_active`` to ``is_stale`` column in ``dag`` | | | | | table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``0e9519b56710`` | ``ec62e120484d`` | ``3.0.0`` | Rename run_type from 'dataset_triggered' to | diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 6611aeb56de46..ef7a6bf6d60e9 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -213,7 +213,7 @@ def ti_run( session.query( func.count(TaskReschedule.id) # or any other primary key column ) - .filter(TaskReschedule.ti_id == ti_id_str, TaskReschedule.try_number == ti.try_number) + .filter(TaskReschedule.ti_id == ti_id_str) .scalar() or 0 ) @@ -309,13 +309,9 @@ def ti_update_state( query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) query = query.values(state=updated_state) elif isinstance(ti_patch_payload, TIRetryStatePayload): - from airflow.models.taskinstance import uuid7 - from airflow.models.taskinstancehistory import TaskInstanceHistory - ti = session.get(TI, ti_id_str) - TaskInstanceHistory.record_ti(ti, session=session) - ti.try_id = uuid7() updated_state = ti_patch_payload.state + 
ti.prepare_db_for_next_try(session) query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) query = query.values(state=updated_state) elif isinstance(ti_patch_payload, TISuccessStatePayload): @@ -393,7 +389,6 @@ def ti_update_state( session.add( TaskReschedule( task_instance.id, - task_instance.try_number, actual_start_date, ti_patch_payload.end_date, ti_patch_payload.reschedule_date, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py index 0c023f2e2711f..d3e940f47a08f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py @@ -17,10 +17,9 @@ from __future__ import annotations -from typing import Annotated from uuid import UUID -from fastapi import Query, status +from fastapi import status from sqlalchemy import select from airflow.api_fastapi.common.db.common import SessionDep @@ -37,18 +36,12 @@ @router.get("/{task_instance_id}/start_date") -def get_start_date( - task_instance_id: UUID, session: SessionDep, try_number: Annotated[int, Query()] = 1 -) -> UtcDateTime | None: +def get_start_date(task_instance_id: UUID, session: SessionDep) -> UtcDateTime | None: """Get the first reschedule date if found, None if no records exist.""" start_date = session.scalar( - select(TaskReschedule) - .where( - TaskReschedule.ti_id == str(task_instance_id), - TaskReschedule.try_number >= try_number, - ) + select(TaskReschedule.start_date) + .where(TaskReschedule.ti_id == str(task_instance_id)) .order_by(TaskReschedule.id.asc()) - .with_only_columns(TaskReschedule.start_date) .limit(1) ) diff --git a/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py new file mode 100644 index 
0000000000000..e0af8511ca724 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py @@ -0,0 +1,130 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Change TI table to have unique UUID id/pk per attempt. + +Revision ID: 29ce7909c52b +Revises: 959e216a3abb +Create Date: 2025-04-09 10:09:53.130924 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy_utils import UUIDType + +# revision identifiers, used by Alembic. 
+revision = "29ce7909c52b" +down_revision = "959e216a3abb" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def _get_uuid_type(dialect_name: str) -> sa.types.TypeEngine: + if dialect_name == "sqlite": + return sa.String(36) + else: + return UUIDType(binary=False) + + +def upgrade(): + """Apply Change TI table to have unique UUID id/pk per attempt.""" + conn = op.get_bind() + dialect_name = conn.dialect.name + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.drop_constraint("task_instance_try_id_uq", type_="unique") + batch_op.drop_column("try_id") + + with op.batch_alter_table("task_instance_history", schema=None) as batch_op: + batch_op.create_index("idx_tih_dag_run", ["dag_id", "run_id"], unique=False) + batch_op.drop_column("task_instance_id") + batch_op.alter_column( + "try_id", + new_column_name="task_instance_id", + existing_type=_get_uuid_type(dialect_name), + existing_nullable=False, + ) + + with op.batch_alter_table("task_instance_note", schema=None) as batch_op: + batch_op.drop_constraint("task_instance_note_ti_fkey", type_="foreignkey") + batch_op.create_foreign_key( + "task_instance_note_ti_fkey", + "task_instance", + ["ti_id"], + ["id"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + # We decided to not migrate/correct the data for task_reschedule as we decided the time to do it wasn't + # worth it, as 90%+ of the data in the table is not needed -- this table only has use when a Sensor task, + # with reschedule mode, is in the `running` or `up_for_reschedule` states. + # + # Going forward, Airflow will delete rows from this table when the TI is recorded in the history table + # (i.e. 
when it's cleared or retried) so the only case in which this data will be accessed and give an + # incorrect figure is when all of these hold true + # + # - a Sensor in reschedule mode + # - in running or up_for_reschedule states + # - On a try_number > 1 + # - with a timeout set + # + # If all of these are true, then the total runtime will be mistakenly calculated as the start of the first + # try to now, rather than the start of the current try. But this is such an unlikely set of circumstances + # that it's not worth the time cost of migrating it. + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.drop_column("try_number") + + +def downgrade(): + """Unapply Change TI table to have unique UUID id/pk per attempt.""" + conn = op.get_bind() + dialect_name = conn.dialect.name + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.add_column( + sa.Column("try_number", sa.INTEGER(), autoincrement=False, nullable=False, default=1) + ) + + with op.batch_alter_table("task_instance_note", schema=None) as batch_op: + batch_op.drop_constraint("task_instance_note_ti_fkey", type_="foreignkey") + batch_op.create_foreign_key( + "task_instance_note_ti_fkey", "task_instance", ["ti_id"], ["id"], ondelete="CASCADE" + ) + + with op.batch_alter_table("task_instance_history", schema=None) as batch_op: + batch_op.alter_column( + "task_instance_id", + new_column_name="try_id", + existing_type=_get_uuid_type(dialect_name), + existing_nullable=False, + ) + batch_op.drop_index("idx_tih_dag_run") + # This has to be in a separate batch, else on sqlite it throws `sqlalchemy.exc.CircularDependencyError` + # (and on non sqlite batching isn't "a thing", it issues alter tables fine) + with op.batch_alter_table("task_instance_history", schema=None) as batch_op: + batch_op.add_column( + sa.Column("task_instance_id", UUIDType(binary=False), autoincrement=False, nullable=False) + ) + + with op.batch_alter_table("task_instance", schema=None) 
as batch_op: + batch_op.add_column(sa.Column("try_id", _get_uuid_type(dialect_name), nullable=False)) + batch_op.create_unique_constraint("task_instance_try_id_uq", ["try_id"]) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 3f668e1618fd3..df4ecf6f56e65 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -49,7 +49,6 @@ not_, or_, text, - tuple_, update, ) from sqlalchemy.dialects import postgresql @@ -1831,7 +1830,7 @@ def schedule_tis( """ # Get list of TI IDs that do not need to executed, these are # tasks using EmptyOperator and without on_execute_callback / on_success_callback - dummy_ti_ids = [] + empty_ti_ids = [] schedulable_ti_ids = [] for ti in schedulable_tis: if TYPE_CHECKING: @@ -1842,7 +1841,7 @@ def schedule_tis( and not ti.task.on_success_callback and not ti.task.outlets ): - dummy_ti_ids.append((ti.task_id, ti.map_index)) + empty_ti_ids.append(ti.id) # check "start_trigger_args" to see whether the operator supports start execution from triggerer # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer # this task. 
@@ -1857,9 +1856,9 @@ def schedule_tis( ti.try_number += 1 ti.defer_task(exception=None, session=session) else: - schedulable_ti_ids.append((ti.task_id, ti.map_index)) + schedulable_ti_ids.append(ti.id) else: - schedulable_ti_ids.append((ti.task_id, ti.map_index)) + schedulable_ti_ids.append(ti.id) count = 0 @@ -1867,14 +1866,10 @@ def schedule_tis( schedulable_ti_ids_chunks = chunks( schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids) ) - for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks: + for id_chunk in schedulable_ti_ids_chunks: count += session.execute( update(TI) - .where( - TI.dag_id == self.dag_id, - TI.run_id == self.run_id, - tuple_(TI.task_id, TI.map_index).in_(schedulable_ti_ids_chunk), - ) + .where(TI.id.in_(id_chunk)) .values( state=TaskInstanceState.SCHEDULED, scheduled_dttm=timezone.utcnow(), @@ -1890,16 +1885,12 @@ def schedule_tis( ).rowcount # Tasks using EmptyOperator should not be executed, mark them as success - if dummy_ti_ids: - dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids)) - for dummy_ti_ids_chunk in dummy_ti_ids_chunks: + if empty_ti_ids: + dummy_ti_ids_chunks = chunks(empty_ti_ids, max_tis_per_query or len(empty_ti_ids)) + for id_chunk in dummy_ti_ids_chunks: count += session.execute( update(TI) - .where( - TI.dag_id == self.dag_id, - TI.run_id == self.run_id, - tuple_(TI.task_id, TI.map_index).in_(dummy_ti_ids_chunk), - ) + .where(TI.id.in_(id_chunk)) .values( state=TaskInstanceState.SUCCESS, start_date=timezone.utcnow(), diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 11a40181f7861..f15b8e9b918b9 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -447,12 +447,10 @@ def clear_task_instances( # taskinstance uuids: task_instance_ids: list[str] = [] dag_bag = DagBag(read_dags_from_db=True) - from airflow.models.taskinstancehistory import 
TaskInstanceHistory for ti in tis: task_instance_ids.append(ti.id) - TaskInstanceHistory.record_ti(ti, session) - ti.try_id = uuid7() + ti.prepare_db_for_next_try(session) if ti.state == TaskInstanceState.RUNNING: # If a task is cleared when running, set its state to RESTARTING so that # the task is terminated and becomes eligible for retry. @@ -477,11 +475,6 @@ def clear_task_instances( ti.clear_next_method_args() session.merge(ti) - if task_instance_ids: - # Clear all reschedules related to the ti to clear - delete_qry = TR.__table__.delete().where(TR.ti_id.in_(task_instance_ids)) - session.execute(delete_qry) - if dag_run_state is not False and tis: from airflow.models.dagrun import DagRun # Avoid circular import @@ -1436,7 +1429,6 @@ def _handle_reschedule( session.add( TaskReschedule( ti.id, - ti.try_number, actual_start_date, ti.end_date, reschedule_exception.reschedule_date, @@ -1486,7 +1478,6 @@ class TaskInstance(Base, LoggingMixin): end_date = Column(UtcDateTime) duration = Column(Float) state = Column(String(20)) - try_id = Column(UUIDType(binary=False), default=uuid7, unique=True, nullable=False) try_number = Column(Integer, default=0) max_tries = Column(Integer, server_default=text("-1")) hostname = Column(String(1000)) @@ -1945,33 +1936,48 @@ def refresh_from_db( :param keep_local_changes: Force all attributes to the values from the database if False (the default), or if True don't overwrite locally set attributes """ - source = TaskInstance.get_task_instance( + query = select( + # Select the columns, not the ORM object, to bypass any session/ORM caching layer + c + for c in TaskInstance.__table__.columns + ).filter_by( dag_id=self.dag_id, - task_id=self.task_id, run_id=self.run_id, + task_id=self.task_id, map_index=self.map_index, - lock_for_update=lock_for_update, - session=session, ) - if source: - from sqlalchemy.orm import attributes - source_state = inspect(source) - if source_state is None: - raise RuntimeError(f"Unable to inspect SQLAlchemy 
state of {type(source)}: {source}") + if lock_for_update: + query = query.with_for_update() + + source = session.execute(query).mappings().one_or_none() + if source: target_state = inspect(self) if target_state is None: raise RuntimeError(f"Unable to inspect SQLAlchemy state of {type(self)}: {self}") - for name, attr in source_state.attrs.items(): - if keep_local_changes and target_state.attrs[name].history.has_changes(): + + # To deal with `@hybrid_property` we need to get the names from `mapper.columns` + for attr_name, col in target_state.mapper.columns.items(): + if keep_local_changes and target_state.attrs[attr_name].history.has_changes(): continue - val = attr.loaded_value + set_committed_value(self, attr_name, source[col.name]) - if val is not attributes.NO_VALUE: - set_committed_value(self, name, val) + # ID may have changed, update SQLAs state and object tracking + newkey = session.identity_key(type(self), (self.id,)) + + # Delete anything under the new key + if newkey != target_state.key: + old = session.identity_map.get(newkey) + if old is not self and old is not None: + session.expunge(old) + target_state.key = newkey + + if target_state.attrs.dag_run.loaded_value is not NO_VALUE: + dr_key = session.identity_key(type(self.dag_run), (self.dag_run.id,)) + if (dr := session.identity_map.get(dr_key)) is not None: + set_committed_value(self, "dag_run", dr) - target_state.key = source_state.key else: self.state = None @@ -2019,32 +2025,6 @@ def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely.""" return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) - @staticmethod - def _set_state(ti: TaskInstance, state, session: Session) -> bool: - if not isinstance(ti, TaskInstance): - ti = session.scalars( - select(TaskInstance).where( - TaskInstance.task_id == ti.task_id, - TaskInstance.dag_id == ti.dag_id, - TaskInstance.run_id == ti.run_id, - TaskInstance.map_index == ti.map_index, - 
) - ).one() - - if ti.state == state: - return False - - current_time = timezone.utcnow() - ti.log.debug("Setting task state for %s to %s", ti, state) - ti.state = state - ti.start_date = ti.start_date or current_time - if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY: - ti.end_date = ti.end_date or current_time - ti.duration = (ti.end_date - ti.start_date).total_seconds() - - session.merge(ti) - return True - @provide_session def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: """ @@ -2054,7 +2034,21 @@ def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: :param session: SQLAlchemy ORM Session :return: Was the state changed """ - return self._set_state(ti=self, state=state, session=session) + if self.state == state: + return False + + current_time = timezone.utcnow() + self.log.debug("Setting task state for %s to %s", self, state) + if self not in session: + self.refresh_from_db(session) + self.state = state + self.start_date = self.start_date or current_time + if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY: + self.end_date = self.end_date or current_time + self.duration = (self.end_date - self.start_date).total_seconds() + session.merge(self) + session.flush() + return True @property def is_premature(self) -> bool: @@ -2062,6 +2056,14 @@ def is_premature(self) -> bool: # is the task still in the retry waiting period? 
return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry() + def prepare_db_for_next_try(self, session: Session): + """Update the metadata with all the records needed to put this TI in queued for the next try.""" + from airflow.models.taskinstancehistory import TaskInstanceHistory + + TaskInstanceHistory.record_ti(self, session=session) + session.execute(delete(TaskReschedule).filter_by(ti_id=self.id)) + self.id = uuid7() + @provide_session def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: """ @@ -2994,10 +2996,7 @@ def fetch_handle_failure_context( # If the task instance is in the running state, it means it raised an exception and # about to retry so we record the task instance history. For other states, the task # instance was cleared and already recorded in the task instance history. - from airflow.models.taskinstancehistory import TaskInstanceHistory - - TaskInstanceHistory.record_ti(ti, session=session) - ti.try_id = uuid7() + ti.prepare_db_for_next_try(session) ti.state = State.UP_FOR_RETRY email_for_state = operator.attrgetter("email_on_retry") @@ -3553,21 +3552,14 @@ def _get_inactive_asset_unique_keys( def get_first_reschedule_date(self, context: Context) -> datetime | None: """Get the first reschedule date for the task instance.""" - # TODO: AIP-72: Remove this after `ti.run` is migrated to use Task SDK - max_tries: int = self.max_tries or 0 - if TYPE_CHECKING: assert isinstance(self.task, BaseOperator) - retries: int = self.task.retries or 0 - first_try_number = max_tries - retries + 1 - with create_session() as session: start_date = session.scalar( select(TaskReschedule) .where( TaskReschedule.ti_id == str(self.id), - TaskReschedule.try_number >= first_try_number, ) .order_by(TaskReschedule.id.asc()) .with_only_columns(TaskReschedule.start_date) @@ -3715,6 +3707,7 @@ class TaskInstanceNote(Base): ], name="task_instance_note_ti_fkey", ondelete="CASCADE", + onupdate="CASCADE", ), ) diff --git 
a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py index 56b9e13d29c61..7e585424cb2f4 100644 --- a/airflow-core/src/airflow/models/taskinstancehistory.py +++ b/airflow-core/src/airflow/models/taskinstancehistory.py @@ -25,6 +25,7 @@ DateTime, Float, ForeignKeyConstraint, + Index, Integer, String, UniqueConstraint, @@ -32,7 +33,6 @@ select, text, ) -from sqlalchemy.dialects import postgresql from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import relationship from sqlalchemy_utils import UUIDType @@ -62,11 +62,7 @@ class TaskInstanceHistory(Base): """ __tablename__ = "task_instance_history" - try_id = Column(UUIDType(binary=False), nullable=False, primary_key=True) - task_instance_id = Column( - String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), - nullable=False, - ) + task_instance_id = Column(UUIDType(binary=False), nullable=False, primary_key=True) task_id = Column(StringID(), nullable=False) dag_id = Column(StringID(), nullable=False) run_id = Column(StringID(), nullable=False) @@ -150,6 +146,7 @@ def __init__( "try_number", name="task_instance_history_dtrt_uq", ), + Index("idx_tih_dag_run", dag_id, run_id), ) @staticmethod diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py index f82b95537ce83..e07d750ebdef7 100644 --- a/airflow-core/src/airflow/models/taskreschedule.py +++ b/airflow-core/src/airflow/models/taskreschedule.py @@ -55,7 +55,6 @@ class TaskReschedule(Base): ForeignKey("task_instance.id", ondelete="CASCADE", name="task_reschedule_ti_fkey"), nullable=False, ) - try_number = Column(Integer, nullable=False) start_date = Column(UtcDateTime, nullable=False) end_date = Column(UtcDateTime, nullable=False) duration = Column(Integer, nullable=False) @@ -67,14 +66,12 @@ class TaskReschedule(Base): def __init__( self, - task_instance_id: uuid.UUID, - try_number: int, + ti_id: uuid.UUID, start_date: 
datetime.datetime, end_date: datetime.datetime, reschedule_date: datetime.datetime, ) -> None: - self.ti_id = task_instance_id - self.try_number = try_number + self.ti_id = ti_id self.start_date = start_date self.end_date = end_date self.reschedule_date = reschedule_date @@ -85,7 +82,6 @@ def stmt_for_task_instance( cls, ti: TaskInstance, *, - try_number: int | None = None, descending: bool = False, ) -> Select: """ @@ -93,14 +89,6 @@ def stmt_for_task_instance( :param ti: the task instance to find task reschedules for :param descending: If True then records are returned in descending order - :param try_number: Look for TaskReschedule of the given try_number. Default is None which - looks for the same try_number of the given task_instance. :meta private: """ - if try_number is None: - try_number = ti.try_number - return ( - select(cls) - .where(cls.ti_id == ti.id, cls.try_number == try_number) - .order_by(desc(cls.id) if descending else asc(cls.id)) - ) + return select(cls).where(cls.ti_id == ti.id).order_by(desc(cls.id) if descending else asc(cls.id)) diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index a563bbebb0af9..e6104a7f44330 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -92,7 +92,7 @@ class MappedClassProtocol(Protocol): "2.9.2": "686269002441", "2.10.0": "22ed7efa9da2", "2.10.3": "5f2621c13b39", - "3.0.0": "959e216a3abb", + "3.0.0": "29ce7909c52b", } diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py index 70b54c31dba25..690b0f7bd7642 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py @@ -189,7 +189,13 @@ def time_freezer(self) -> Generator: freezer.stop() @pytest.fixture(autouse=True) - def setup(self) -> None: + def setup(self): + 
clear_db_assets() + clear_db_runs() + clear_db_logs() + + yield + clear_db_assets() clear_db_runs() clear_db_logs() @@ -229,7 +235,7 @@ def create_asset_dag_run(self, session, num: int = 2): class TestGetAssets(TestAssets): def test_should_respond_200(self, test_client, session): - self.create_assets() + assets1, asset2 = self.create_assets(session) session.add(AssetModel("inactive", "inactive")) session.commit() @@ -243,7 +249,7 @@ def test_should_respond_200(self, test_client, session): assert response_data == { "assets": [ { - "id": 1, + "id": assets1.id, "name": "simple1", "uri": "s3://bucket/key/1", "group": "asset", @@ -255,7 +261,7 @@ def test_should_respond_200(self, test_client, session): "aliases": [], }, { - "id": 2, + "id": asset2.id, "name": "simple2", "uri": "s3://bucket/key/2", "group": "asset", @@ -271,10 +277,9 @@ def test_should_respond_200(self, test_client, session): } def test_should_show_inactive(self, test_client, session): - self.create_assets() + asset1, asset2 = self.create_assets(session) session.add( - AssetModel( - id=3, + asset3 := AssetModel( name="simple3", uri="s3://bucket/key/3", group="asset", @@ -295,7 +300,7 @@ def test_should_show_inactive(self, test_client, session): assert response_data == { "assets": [ { - "id": 1, + "id": asset1.id, "name": "simple1", "uri": "s3://bucket/key/1", "group": "asset", @@ -307,7 +312,7 @@ def test_should_show_inactive(self, test_client, session): "aliases": [], }, { - "id": 2, + "id": asset2.id, "name": "simple2", "uri": "s3://bucket/key/2", "group": "asset", @@ -319,7 +324,7 @@ def test_should_show_inactive(self, test_client, session): "aliases": [], }, { - "id": 3, + "id": asset3.id, "name": "simple3", "uri": "s3://bucket/key/3", "group": "asset", @@ -624,11 +629,12 @@ def test_should_respect_page_size_limit_default(self, test_client): class TestGetAssetEvents(TestAssets): def test_should_respond_200(self, test_client, session): - self.create_assets() - self.create_assets_events() - 
self.create_dag_run() - self.create_asset_dag_run() + asset1, asset2 = self.create_assets(session) + self.create_assets_events(session) + self.create_dag_run(session) + self.create_asset_dag_run(session) assets = session.query(AssetEvent).all() + session.commit() assert len(assets) == 2 response = test_client.get("/assets/events") assert response.status_code == 200 @@ -944,6 +950,8 @@ def test_should_respond_404(self, test_client): class TestQueuedEventEndpoint(TestAssets): def _create_asset_dag_run_queues(self, dag_id, asset_id, session): + session.query(AssetDagRunQueue).delete() + session.flush() adrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) session.add(adrq) session.commit() @@ -955,9 +963,8 @@ class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint): def test_should_respond_200(self, test_client, session, create_dummy_dag): dag, _ = create_dummy_dag() dag_id = dag.dag_id - self.create_assets(session=session, num=1) - asset_id = 1 - self._create_asset_dag_run_queues(dag_id, asset_id, session) + (asset,) = self.create_assets(session=session, num=1) + self._create_asset_dag_run_queues(dag_id, asset.id, session) response = test_client.get( f"/dags/{dag_id}/assets/queuedEvents", @@ -967,7 +974,7 @@ def test_should_respond_200(self, test_client, session, create_dummy_dag): assert response.json() == { "queued_events": [ { - "asset_id": 1, + "asset_id": asset.id, "dag_id": "dag", "created_at": from_datetime_to_zulu_without_ms(DEFAULT_DATE), } @@ -1050,13 +1057,13 @@ def test_should_respond_404_valid_dag_no_adrq(self, test_client, session, create class TestPostAssetEvents(TestAssets): @pytest.mark.usefixtures("time_freezer") def test_should_respond_200(self, test_client, session): - self.create_assets(session) - event_payload = {"asset_id": 1, "extra": {"foo": "bar"}} + (asset,) = self.create_assets(session, num=1) + event_payload = {"asset_id": asset.id, "extra": {"foo": "bar"}} response = test_client.post("/assets/events", json=event_payload) 
assert response.status_code == 200 assert response.json() == { "id": mock.ANY, - "asset_id": 1, + "asset_id": asset.id, "uri": "s3://bucket/key/1", "group": "asset", "name": "simple1", @@ -1088,13 +1095,13 @@ def test_invalid_attr_not_allowed(self, test_client, session): @pytest.mark.usefixtures("time_freezer") @pytest.mark.enable_redact def test_should_mask_sensitive_extra(self, test_client, session): - self.create_assets(session) - event_payload = {"asset_id": 1, "extra": {"password": "bar"}} + (asset,) = self.create_assets(session, num=1) + event_payload = {"asset_id": asset.id, "extra": {"password": "bar"}} response = test_client.post("/assets/events", json=event_payload) assert response.status_code == 200 assert response.json() == { "id": mock.ANY, - "asset_id": 1, + "asset_id": asset.id, "uri": "s3://bucket/key/1", "group": "asset", "name": "simple1", @@ -1118,7 +1125,9 @@ class TestPostAssetMaterialize(TestAssets): @pytest.fixture(autouse=True) def create_dags(self, setup, dag_maker, session): # Depend on 'setup' so it runs first. Otherwise it deletes what we create here. 
- assets = {am.id: am.to_public() for am in self.create_assets(session=session, num=3)} + assets = { + i: am.to_public() for i, am in enumerate(self.create_assets(session=session, num=3), start=1) + } with dag_maker(self.DAG_ASSET1_ID, schedule=None, session=session): EmptyOperator(task_id="task", outlets=assets[1]) with dag_maker(self.DAG_ASSET2_ID_A, schedule=None, session=session): @@ -1176,16 +1185,15 @@ class TestGetAssetQueuedEvents(TestQueuedEventEndpoint): def test_should_respond_200(self, test_client, session, create_dummy_dag): dag, _ = create_dummy_dag() dag_id = dag.dag_id - self.create_assets(session=session, num=1) - asset_id = 1 - self._create_asset_dag_run_queues(dag_id, asset_id, session) + (asset,) = self.create_assets(session=session, num=1) + self._create_asset_dag_run_queues(dag_id, asset.id, session) - response = test_client.get(f"/assets/{asset_id}/queuedEvents") + response = test_client.get(f"/assets/{asset.id}/queuedEvents") assert response.status_code == 200 assert response.json() == { "queued_events": [ { - "asset_id": asset_id, + "asset_id": asset.id, "dag_id": "dag", "created_at": from_datetime_to_zulu_without_ms(DEFAULT_DATE), } @@ -1212,14 +1220,13 @@ class TestDeleteAssetQueuedEvents(TestQueuedEventEndpoint): def test_should_respond_204(self, test_client, session, create_dummy_dag): dag, _ = create_dummy_dag() dag_id = dag.dag_id - self.create_assets(session=session, num=1) - asset_id = 1 - self._create_asset_dag_run_queues(dag_id, asset_id, session) + (asset,) = self.create_assets(session=session, num=1) + self._create_asset_dag_run_queues(dag_id, asset.id, session) - assert session.get(AssetDagRunQueue, (asset_id, dag_id)) is not None - response = test_client.delete(f"/assets/{asset_id}/queuedEvents") + assert session.get(AssetDagRunQueue, (asset.id, dag_id)) is not None + response = test_client.delete(f"/assets/{asset.id}/queuedEvents") assert response.status_code == 204 - assert session.get(AssetDagRunQueue, (asset_id, dag_id)) 
is None + assert session.get(AssetDagRunQueue, (asset.id, dag_id)) is None check_last_log(session, dag_id=None, event="delete_asset_queued_events", logical_date=None) def test_should_respond_401(self, unauthenticated_test_client): @@ -1240,15 +1247,14 @@ class TestDeleteDagAssetQueuedEvent(TestQueuedEventEndpoint): def test_delete_should_respond_204(self, test_client, session, create_dummy_dag): dag, _ = create_dummy_dag() dag_id = dag.dag_id - self.create_assets(session=session, num=1) - asset_id = 1 + (asset,) = self.create_assets(session=session, num=1) - self._create_asset_dag_run_queues(dag_id, asset_id, session) + self._create_asset_dag_run_queues(dag_id, asset.id, session) adrq = session.query(AssetDagRunQueue).all() assert len(adrq) == 1 response = test_client.delete( - f"/dags/{dag_id}/assets/{asset_id}/queuedEvents", + f"/dags/{dag_id}/assets/{asset.id}/queuedEvents", ) assert response.status_code == 204 diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py index 11f3846cacd78..b391fc66492b7 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py @@ -25,6 +25,7 @@ import pytest from itsdangerous.url_safe import URLSafeSerializer +from uuid6 import uuid7 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.models.dag import DAG @@ -72,6 +73,18 @@ def add_one(x: int): self.app.state.dag_bag.bag_dag(dag) + for ti in dr.task_instances: + ti.try_number = 1 + ti.hostname = "localhost" + session.merge(ti) + dag.clear() + for ti in dr.task_instances: + ti.try_number = 2 + ti.id = str(uuid7()) + ti.hostname = "localhost" + session.merge(ti) + session.flush() + # Add dummy dag for checking picking correct log with same task_id and different dag_id case. 
with dag_maker( f"{self.DAG_ID}_copy", start_date=timezone.parse(self.default_time), session=session @@ -85,27 +98,21 @@ def add_one(x: int): ) self.app.state.dag_bag.bag_dag(dummy_dag) - for ti in dr.task_instances: - ti.try_number = 1 - ti.hostname = "localhost" - session.merge(ti) for ti in dr2.task_instances: ti.try_number = 1 ti.hostname = "localhost" session.merge(ti) - session.flush() - dag.clear() dummy_dag.clear() - for ti in dr.task_instances: - ti.try_number = 2 - ti.hostname = "localhost" - session.merge(ti) for ti in dr2.task_instances: ti.try_number = 2 + ti.id = str(uuid7()) ti.hostname = "localhost" session.merge(ti) + session.flush() session.flush() + ... + @pytest.fixture def configure_loggers(self, tmp_path, create_log_template): self.log_dir = tmp_path @@ -165,6 +172,7 @@ def test_should_respond_200_json(self, try_number): ) expected_filename = f"{self.log_dir}/dag_id={self.DAG_ID}/run_id={self.RUN_ID}/task_id={self.TASK_ID}/attempt={try_number}.log" log_content = "Log for testing." if try_number == 1 else "Log for testing 2." 
+ assert response.status_code == 200, response.json() resp_contnt = response.json()["content"] assert expected_filename in resp_contnt[0]["sources"] assert log_content in resp_contnt[2]["event"] diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 4678e6c30eaf1..7b7b3e531624d 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -96,7 +96,7 @@ def create_task_instances( dag_id: str = "example_python_operator", update_extras: bool = True, task_instances=None, - dag_run_state=State.RUNNING, + dag_run_state=DagRunState.RUNNING, with_ti_history=False, ): """Method to create task instances using kwargs and default arguments""" @@ -133,6 +133,7 @@ def create_task_instances( state=dag_run_state, ) session.add(dr) + session.flush() ti = TaskInstance(task=tasks[i], **self.ti_init) session.add(ti) ti.dag_run = dr @@ -142,18 +143,20 @@ def create_task_instances( setattr(ti, key, value) tis.append(ti) - session.commit() + session.flush() + if with_ti_history: for ti in tis: ti.try_number = 1 session.merge(ti) - session.commit() - dag.clear() + session.flush() + dag.clear(session=session) for ti in tis: ti.try_number = 2 ti.queue = "default_queue" session.merge(ti) - session.commit() + session.flush() + session.commit() return tis diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 456c3f9fa2a48..c933c0091d53f 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -707,7 +707,7 @@ def test_ti_update_state_to_reschedule(self, client, 
session, create_task_instan assert trs[0].task_instance.dag_id == "dag" assert trs[0].task_instance.task_id == "test_ti_update_state_to_reschedule" assert trs[0].task_instance.run_id == "test" - assert trs[0].try_number == 0 + assert trs[0].ti_id == tis[0].id assert trs[0].start_date == instant assert trs[0].end_date == DEFAULT_END_DATE assert trs[0].reschedule_date == timezone.parse("2024-10-31T11:03:00+00:00") @@ -763,20 +763,21 @@ def test_ti_update_state_handle_retry(self, client, session, create_task_instanc assert response.status_code == 204 assert response.text == "" - session.expire_all() - - ti = session.get(TaskInstance, ti.id) + ti = session.scalar( + select(TaskInstance).filter_by(task_id=ti.task_id, run_id=ti.run_id, dag_id=ti.dag_id) + ) + # ti = session.get(TaskInstance, ti.id) assert ti.state == State.UP_FOR_RETRY assert ti.next_method is None assert ti.next_kwargs is None tih = ( session.query(TaskInstanceHistory) - .where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.task_instance_id == ti.id) + .where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.run_id == ti.run_id) .one() ) - assert tih.try_id - assert tih.try_id != ti.try_id + assert tih.task_instance_id + assert tih.task_instance_id != ti.id def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance): # we just want to fail in this test, no need to retry @@ -1204,8 +1205,7 @@ def test_get_start_date(self, client, session, create_task_instance): session=session, ) tr = TaskReschedule( - task_instance_id=ti.id, - try_number=1, + ti_id=ti.id, start_date=timezone.datetime(2024, 1, 1), end_date=timezone.datetime(2024, 1, 1, 1), reschedule_date=timezone.datetime(2024, 1, 1, 2), @@ -1222,37 +1222,6 @@ def test_get_start_date_not_found(self, client): response = client.get(f"/execution/task-reschedules/{ti_id}/start_date") assert response.json() is None - def test_get_start_date_with_try_number(self, client, session, 
create_task_instance): - # Create multiple reschedules - dates = [ - timezone.datetime(2024, 1, 1), - timezone.datetime(2024, 1, 2), - timezone.datetime(2024, 1, 3), - ] - - ti = create_task_instance( - task_id="test_get_start_date_with_try_number", - state=State.RUNNING, - start_date=timezone.datetime(2024, 1, 1), - session=session, - ) - - for i, date in enumerate(dates, 1): - tr = TaskReschedule( - task_instance_id=ti.id, - try_number=i, - start_date=date, - end_date=date.replace(hour=1), - reschedule_date=date.replace(hour=2), - ) - session.add(tr) - session.commit() - - # Test getting start date for try_number 2 - response = client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2") - assert response.status_code == 200 - assert response.json() == "2024-01-02T00:00:00Z" - class TestGetCount: def setup_method(self): diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 7e98c1683fad5..b1119628e22a7 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -1711,19 +1711,21 @@ def _create_dagruns(): # First pass we'll grab 6 of the 8 tasks (limited by max_tis_per_query) res = self.job_runner._critical_section_enqueue_task_instances(session) assert res == 6 + session.flush() for ti in tis1[:3] + tis2[:3]: - ti.refresh_from_db() + ti.refresh_from_db(session) assert ti.state == TaskInstanceState.QUEUED for ti in tis1[3:] + tis2[3:]: - ti.refresh_from_db() + ti.refresh_from_db(session) assert ti.state == TaskInstanceState.SCHEDULED # The remaining TIs are queued res = self.job_runner._critical_section_enqueue_task_instances(session) assert res == 2 + session.flush() for ti in tis1 + tis2: - ti.refresh_from_db() + ti.refresh_from_db(session) assert ti.state == State.QUEUED @pytest.mark.parametrize( @@ -2555,8 +2557,9 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak session = 
settings.Session() dr = dag_maker.create_dagrun() - ti = dr.get_task_instance("dummy") + ti = dr.get_task_instance("dummy", session) ti.set_state(state, session) + session.flush() with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): self.job_runner._do_scheduling(session) @@ -2579,7 +2582,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak @pytest.mark.parametrize( "state, expected_callback_msg", [(State.SUCCESS, "success"), (State.FAILED, "task_failure")] ) - def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_maker): + def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_maker, session): """ Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor. """ @@ -2587,6 +2590,7 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak dag_id="test_dagrun_callbacks_are_called", on_success_callback=lambda x: print("success"), on_failure_callback=lambda x: print("failed"), + session=session, ): EmptyOperator(task_id="dummy") @@ -2598,10 +2602,9 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak self.job_runner.dagbag = dag_maker.dagbag - session = settings.Session() dr = dag_maker.create_dagrun() - ti = dr.get_task_instance("dummy") + ti = dr.get_task_instance("dummy", session) ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): @@ -2611,8 +2614,8 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak dag_listener.success = [] dag_listener.failure = [] + session.rollback() - session.close() def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, session): with dag_maker( @@ -2700,12 +2703,13 @@ def mock_send_dag_callbacks_to_processor(*args, **kwargs): session.close() @pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED]) - def 
test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state, dag_maker): + def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state, dag_maker, session): """ Test if no on_*_callback are defined on DAG, Callbacks not registered and sent to DAG Processor """ with dag_maker( dag_id="test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined", + session=session, ): BashOperator(task_id="test_task", bash_command="echo hi") @@ -2714,9 +2718,8 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta self.job_runner._send_dag_callbacks_to_processor = mock.Mock() - session = settings.Session() dr = dag_maker.create_dagrun() - ti = dr.get_task_instance("test_task") + ti = dr.get_task_instance("test_task", session) ti.set_state(state, session) with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): @@ -2729,7 +2732,6 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta assert call_args[1] is None session.rollback() - session.close() @pytest.mark.parametrize("state, msg", [[State.SUCCESS, "success"], [State.FAILED, "task_failure"]]) def test_dagrun_callbacks_are_added_when_callbacks_are_defined(self, state, msg, dag_maker): @@ -3483,10 +3485,10 @@ def run_with_error(ti, ignore_ti_state=False): assert ti.state == State.UP_FOR_RETRY assert ti.try_number == 1 - with create_session() as session: - ti.refresh_from_db(lock_for_update=True, session=session) - ti.state = State.SCHEDULED - session.merge(ti) + ti.refresh_from_db(lock_for_update=True, session=session) + ti.state = State.SCHEDULED + session.merge(ti) + session.commit() # To verify that task does get re-queued. 
executor.do_update = True @@ -6223,12 +6225,10 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): run_type=DagRunType.SCHEDULED, ) scheduled_run.last_scheduling_decision = datetime.datetime.now(timezone.utc) - timedelta(minutes=1) - ti = scheduled_run.get_task_instances()[0] + ti = scheduled_run.get_task_instances(session=session)[0] ti.set_state(TaskInstanceState.RUNNING) - dm = DagModel.get_dagmodel(dag.dag_id) + dm = DagModel.get_dagmodel(dag.dag_id, session) dm.is_paused = True - session.merge(dm) - session.merge(ti) session.flush() assert scheduled_run.state == State.RUNNING @@ -6252,7 +6252,8 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): assert prior_last_scheduling_decision == scheduled_run.last_scheduling_decision # Once the TI is in a terminal state though, DagRun goes to success - ti.set_state(TaskInstanceState.SUCCESS) + ti.set_state(TaskInstanceState.SUCCESS, session=session) + self.job_runner._update_dag_run_state_for_paused_dags(session=session) (scheduled_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.SCHEDULED, session=session) assert scheduled_run.state == State.SUCCESS @@ -6265,12 +6266,10 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se # Backfill run backfill_run = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB) backfill_run.last_scheduling_decision = datetime.datetime.now(timezone.utc) - timedelta(minutes=1) - ti = backfill_run.get_task_instances()[0] - ti.set_state(TaskInstanceState.SUCCESS) - dm = DagModel.get_dagmodel(dag.dag_id) + ti = backfill_run.get_task_instances(session=session)[0] + ti.set_state(TaskInstanceState.SUCCESS, session=session) + dm = DagModel.get_dagmodel(dag.dag_id, session=session) dm.is_paused = True - session.merge(dm) - session.merge(ti) session.flush() assert backfill_run.state == State.RUNNING @@ -6278,7 +6277,7 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se scheduler_job = 
Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job) - self.job_runner._update_dag_run_state_for_paused_dags() + self.job_runner._update_dag_run_state_for_paused_dags(session=session) session.flush() (backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session) diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index f6626c3d7816b..9bcae8cceba51 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -76,8 +76,8 @@ def test_clear_task_instances(self, dag_maker): # do the incrementing of try_number ordinarily handled by scheduler ti0.try_number += 1 ti1.try_number += 1 - session.merge(ti0) - session.merge(ti1) + ti0 = session.merge(ti0) + ti1 = session.merge(ti1) session.commit() # we use order_by(task_id) here because for the test DAG structure of ours @@ -87,8 +87,9 @@ def test_clear_task_instances(self, dag_maker): qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all() clear_task_instances(qry, session, dag=dag) - ti0.refresh_from_db() - ti1.refresh_from_db() + ti0.refresh_from_db(session) + ti1.refresh_from_db(session) + # Next try to run will be try 2 assert ti0.state is None assert ti0.try_number == 1 @@ -530,11 +531,7 @@ def test_clear_task_instances_with_task_reschedule(self, dag_maker): with create_session() as session: def count_task_reschedule(ti): - return ( - session.query(TaskReschedule) - .filter(TaskReschedule.ti_id == ti.id, TaskReschedule.try_number == 1) - .count() - ) + return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count() assert count_task_reschedule(ti0) == 1 assert count_task_reschedule(ti1) == 1 @@ -626,12 +623,14 @@ def test_dag_clear(self, dag_maker): assert ti0.try_number == 1 dag.clear() ti0.refresh_from_db() + ti1.refresh_from_db() assert ti0.try_number == 1 assert ti0.state == 
State.NONE assert ti0.max_tries == 1 assert ti1.max_tries == 2 - session.get(TaskInstance, ti1.id).try_number += 1 + session.add(ti1) + ti1.try_number += 1 session.commit() ti1.run() @@ -747,29 +746,37 @@ def test_operator_clear(self, dag_maker, session): run_type=DagRunType.SCHEDULED, ) - ti1, ti2 = sorted(dr.task_instances, key=lambda ti: ti.task_id) + ti1, ti2 = sorted(dr.get_task_instances(session=session), key=lambda ti: ti.task_id) ti1.task = op1 ti2.task = op2 session.get(TaskInstance, ti2.id).try_number += 1 session.commit() - ti2.run() + ti2.run(session=session) # Dependency not met assert ti2.try_number == 1 assert ti2.max_tries == 1 - op2.clear(upstream=True) + op2.clear(upstream=True, session=session) + ti1.refresh_from_db(session) + ti2.refresh_from_db(session) # max tries will be set to retries + curr try number == 1 + 1 == 2 - assert session.get(TaskInstance, ti2.id).max_tries == 2 + assert ti2.max_tries == 2 - session.get(TaskInstance, ti1.id).try_number += 1 + ti1.try_number += 1 + session.merge(ti1) session.commit() - ti1.run() + + ti1.run(session=session) + ti1.refresh_from_db(session) + ti2.refresh_from_db(session) assert ti1.try_number == 1 - session.get(TaskInstance, ti2.id).try_number += 1 - session.commit() - ti2.run(ignore_ti_state=True) + ti2.try_number += 1 + session.add(ti2) + session.flush() + ti2.run(ignore_ti_state=True, session=session) + ti2.refresh_from_db(session) # max_tries is 0 because there is no task instance in db for ti1 # so clear won't change the max_tries. 
assert ti1.max_tries == 0 diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index dacfcf7a75a22..35f4ec7964850 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -2403,8 +2403,7 @@ def printx(x): ti = session.query(TaskInstance).filter_by(**filter_kwargs).one() tr = TaskReschedule( - task_instance_id=ti.id, - try_number=ti.try_number, + ti_id=ti.id, start_date=timezone.datetime(2017, 1, 1), end_date=timezone.datetime(2017, 1, 2), reschedule_date=timezone.datetime(2017, 1, 1), diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 1fc9de2a3b600..acbcb1939fd7a 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -724,11 +724,14 @@ def run_with_error(ti): # clearing it first dag.clear() - session.get(TaskInstance, ti.id).try_number += 1 + ti.refresh_from_db(session) + ti.try_number += 1 + session.add(ti) session.commit() # third run -- up for retry run_with_error(ti) + ti.refresh_from_db() assert ti.state == State.UP_FOR_RETRY assert ti.try_number == 3 @@ -901,7 +904,7 @@ def run_ti_and_assert( run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1) done, fail = False, True - run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 1) + run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 0) # scheduler would create a new try here with create_session() as session: @@ -1005,7 +1008,7 @@ def run_ti_and_assert( run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1) done, fail = False, True - run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 1) + run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 0) with create_session() as session: session.get(TaskInstance, ti.id).try_number += 1 @@ -1041,7 +1044,6 @@ def func(): 
).expand(poke_interval=[0]) ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0] ti.task = task - assert ti.try_number == 0 def run_ti_and_assert( run_date, @@ -3537,8 +3539,7 @@ def test_handle_failure_task_undefined(self, create_task_instance): del ti.task ti.handle_failure("test ti.task undefined") - @provide_session - def test_handle_failure_fail_fast(self, create_dummy_dag, session=None): + def test_handle_failure_fail_fast(self, create_dummy_dag, session): start_date = timezone.datetime(2016, 6, 1) clear_db_runs() @@ -3588,7 +3589,7 @@ def test_handle_failure_fail_fast(self, create_dummy_dag, session=None): ti_ff = TI(task=fail_task, run_id=dr.run_id) ti_ff.state = State.FAILED session.add(ti_ff) - session.flush() + session.commit() ti_ff.handle_failure("test retry handling") assert ti1.state == State.SUCCESS @@ -3880,7 +3881,6 @@ def test_refresh_from_db(self, create_task_instance): "end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234), "duration": 1.234, "state": State.SUCCESS, - "try_id": mock.ANY, "try_number": 1, "max_tries": 1, "hostname": "some_unique_hostname", @@ -4023,19 +4023,18 @@ def test_task_instance_history_is_created_when_ti_goes_for_retry(self, dag_maker dr = dag_maker.create_dagrun() ti = dr.task_instances[0] ti.task = task - try_id = ti.try_id + try_id = ti.id with pytest.raises(AirflowException): ti.run() ti = session.query(TaskInstance).one() - # the ti.try_id should be different from the previous one - assert ti.try_id != try_id + # the ti.id should be different from the previous one + assert ti.id != try_id assert ti.state == State.UP_FOR_RETRY assert session.query(TaskInstance).count() == 1 tih = session.query(TaskInstanceHistory).all() assert len(tih) == 1 # the new try_id should be different from what's recorded in tih - assert tih[0].try_id == try_id - assert tih[0].try_id != ti.try_id + assert str(tih[0].task_instance_id) == try_id @pytest.mark.skip( reason="This test has some 
issues that were surfaced when dag_maker started allowing multiple serdag versions. Issue #48539 will track fixing this." diff --git a/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py index c56d0b84b3553..7b4bd22aa0436 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -88,8 +88,7 @@ def _create_task_reschedule(self, ti, minutes: int | list[int]): dt = ti.logical_date + timedelta(minutes=minutes_timedelta) trs.append( TaskReschedule( - task_instance_id=ti.id, - try_number=ti.try_number, + ti_id=ti.id, start_date=dt, end_date=dt, reschedule_date=dt, diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 83197298733ce..4261ae80cd0ce 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -78,6 +78,7 @@ def sort_by(item): state = self.mock_task_results[key] ti.set_state(state, session=session) self.change_state(key, state) + session.flush() def terminate(self): pass diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 20636cd07059f..5d0c99dc354bb 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -480,6 +480,7 @@ def f(): branch_ti.set_state(TaskInstanceState.SUCCESS, session=session) dr.task_instance_scheduling_decisions(session=session) branch_2_ti = dr.get_task_instance(task_id="branch_2", session=session) + branch_2_ti.task = self.branch_2 assert branch_2_ti.state == TaskInstanceState.SKIPPED branch_2_ti.set_state(None) branch_2_ti.run() @@ -761,6 +762,7 @@ def test_clear_skipped_downstream_task(self): 
sc_ti.set_state(TaskInstanceState.SUCCESS, session=session) dr.task_instance_scheduling_decisions(session=session) op1_ti = dr.get_task_instance(task_id="op1", session=session) + op1_ti.task = self.op1 assert op1_ti.state == TaskInstanceState.SKIPPED op1_ti.set_state(None) op1_ti.run() @@ -1709,6 +1711,7 @@ def f(): branch_2_ti = dr.get_task_instance(task_id="branch_2", session=session) # FIXME if self.opcls != BranchExternalPythonOperator: + branch_2_ti.task = self.branch_2 assert branch_2_ti.state == TaskInstanceState.SKIPPED branch_2_ti.set_state(None, session=session) branch_2_ti.run()