diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256
index ec5a663b3429b..fb89acbe007e2 100644
--- a/airflow-core/docs/img/airflow_erd.sha256
+++ b/airflow-core/docs/img/airflow_erd.sha256
@@ -1 +1 @@
-20aac73bfb0e08226c687304b997ed93d9cf796d3b761d14e45d732ef4f6c01e
\ No newline at end of file
+db00d57fce32830b69f2c1481b231e65e67e197b4a96a5fa1c870cd555eac3bd
\ No newline at end of file
diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg
index dbcc7d4060ec2..7398a42070836 100644
--- a/airflow-core/docs/img/airflow_erd.svg
+++ b/airflow-core/docs/img/airflow_erd.svg
@@ -714,169 +714,164 @@
task_instance
-
-task_instance
-
-id
-
- [UUID]
- NOT NULL
-
-context_carrier
-
- [JSONB]
-
-custom_operator_name
-
- [VARCHAR(1000)]
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-dag_version_id
-
- [UUID]
-
-duration
-
- [DOUBLE_PRECISION]
-
-end_date
-
- [TIMESTAMP]
-
-executor
-
- [VARCHAR(1000)]
-
-executor_config
-
- [BYTEA]
-
-external_executor_id
-
- [VARCHAR(250)]
-
-hostname
-
- [VARCHAR(1000)]
-
-last_heartbeat_at
-
- [TIMESTAMP]
-
-map_index
-
- [INTEGER]
- NOT NULL
-
-max_tries
-
- [INTEGER]
-
-next_kwargs
-
- [JSONB]
-
-next_method
-
- [VARCHAR(1000)]
-
-operator
-
- [VARCHAR(1000)]
-
-pid
-
- [INTEGER]
-
-pool
-
- [VARCHAR(256)]
- NOT NULL
-
-pool_slots
-
- [INTEGER]
- NOT NULL
-
-priority_weight
-
- [INTEGER]
-
-queue
-
- [VARCHAR(256)]
-
-queued_by_job_id
-
- [INTEGER]
-
-queued_dttm
-
- [TIMESTAMP]
-
-rendered_map_index
-
- [VARCHAR(250)]
-
-run_id
-
- [VARCHAR(250)]
- NOT NULL
-
-scheduled_dttm
-
- [TIMESTAMP]
-
-span_status
-
- [VARCHAR(250)]
- NOT NULL
-
-start_date
-
- [TIMESTAMP]
-
-state
-
- [VARCHAR(20)]
-
-task_display_name
-
- [VARCHAR(2000)]
-
-task_id
-
- [VARCHAR(250)]
- NOT NULL
-
-trigger_id
-
- [INTEGER]
-
-trigger_timeout
-
- [TIMESTAMP]
-
-try_id
-
- [UUID]
- NOT NULL
-
-try_number
-
- [INTEGER]
-
-unixname
-
- [VARCHAR(1000)]
-
-updated_at
-
- [TIMESTAMP]
+
+task_instance
+
+id
+
+ [UUID]
+ NOT NULL
+
+context_carrier
+
+ [JSONB]
+
+custom_operator_name
+
+ [VARCHAR(1000)]
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+dag_version_id
+
+ [UUID]
+
+duration
+
+ [DOUBLE_PRECISION]
+
+end_date
+
+ [TIMESTAMP]
+
+executor
+
+ [VARCHAR(1000)]
+
+executor_config
+
+ [BYTEA]
+
+external_executor_id
+
+ [VARCHAR(250)]
+
+hostname
+
+ [VARCHAR(1000)]
+
+last_heartbeat_at
+
+ [TIMESTAMP]
+
+map_index
+
+ [INTEGER]
+ NOT NULL
+
+max_tries
+
+ [INTEGER]
+
+next_kwargs
+
+ [JSONB]
+
+next_method
+
+ [VARCHAR(1000)]
+
+operator
+
+ [VARCHAR(1000)]
+
+pid
+
+ [INTEGER]
+
+pool
+
+ [VARCHAR(256)]
+ NOT NULL
+
+pool_slots
+
+ [INTEGER]
+ NOT NULL
+
+priority_weight
+
+ [INTEGER]
+
+queue
+
+ [VARCHAR(256)]
+
+queued_by_job_id
+
+ [INTEGER]
+
+queued_dttm
+
+ [TIMESTAMP]
+
+rendered_map_index
+
+ [VARCHAR(250)]
+
+run_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+scheduled_dttm
+
+ [TIMESTAMP]
+
+span_status
+
+ [VARCHAR(250)]
+ NOT NULL
+
+start_date
+
+ [TIMESTAMP]
+
+state
+
+ [VARCHAR(20)]
+
+task_display_name
+
+ [VARCHAR(2000)]
+
+task_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+trigger_id
+
+ [INTEGER]
+
+trigger_timeout
+
+ [TIMESTAMP]
+
+try_number
+
+ [INTEGER]
+
+unixname
+
+ [VARCHAR(1000)]
+
+updated_at
+
+ [TIMESTAMP]
@@ -888,477 +883,467 @@
rendered_task_instance_fields
-
-rendered_task_instance_fields
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-map_index
-
- [INTEGER]
- NOT NULL
-
-run_id
-
- [VARCHAR(250)]
- NOT NULL
-
-task_id
-
- [VARCHAR(250)]
- NOT NULL
-
-k8s_pod_yaml
-
- [JSON]
-
-rendered_fields
-
- [JSON]
- NOT NULL
+
+rendered_task_instance_fields
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+map_index
+
+ [INTEGER]
+ NOT NULL
+
+run_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+task_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+k8s_pod_yaml
+
+ [JSON]
+
+rendered_fields
+
+ [JSON]
+ NOT NULL
task_instance--rendered_task_instance_fields
-
-0..N
-1
+
+0..N
+1
task_instance--rendered_task_instance_fields
-
-0..N
-1
+
+0..N
+1
task_instance--rendered_task_instance_fields
-
-0..N
-1
+
+0..N
+1
task_instance--rendered_task_instance_fields
-
-0..N
-1
+
+0..N
+1
task_map
-
-task_map
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-map_index
-
- [INTEGER]
- NOT NULL
-
-run_id
-
- [VARCHAR(250)]
- NOT NULL
-
-task_id
-
- [VARCHAR(250)]
- NOT NULL
-
-keys
-
- [JSONB]
-
-length
-
- [INTEGER]
- NOT NULL
+
+task_map
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+map_index
+
+ [INTEGER]
+ NOT NULL
+
+run_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+task_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+keys
+
+ [JSONB]
+
+length
+
+ [INTEGER]
+ NOT NULL
task_instance--task_map
-
-0..N
-1
+
+0..N
+1
task_instance--task_map
-
-0..N
-1
+
+0..N
+1
task_instance--task_map
-
-0..N
-1
+
+0..N
+1
task_instance--task_map
-
-0..N
-1
+
+0..N
+1
task_reschedule
-
-task_reschedule
-
-id
-
- [INTEGER]
- NOT NULL
-
-duration
-
- [INTEGER]
- NOT NULL
-
-end_date
-
- [TIMESTAMP]
- NOT NULL
-
-reschedule_date
-
- [TIMESTAMP]
- NOT NULL
-
-start_date
-
- [TIMESTAMP]
- NOT NULL
-
-ti_id
-
- [UUID]
- NOT NULL
-
-try_number
-
- [INTEGER]
- NOT NULL
+
+task_reschedule
+
+id
+
+ [INTEGER]
+ NOT NULL
+
+duration
+
+ [INTEGER]
+ NOT NULL
+
+end_date
+
+ [TIMESTAMP]
+ NOT NULL
+
+reschedule_date
+
+ [TIMESTAMP]
+ NOT NULL
+
+start_date
+
+ [TIMESTAMP]
+ NOT NULL
+
+ti_id
+
+ [UUID]
+ NOT NULL
task_instance--task_reschedule
-
-0..N
-1
+
+0..N
+1
xcom
-
-xcom
-
-dag_run_id
-
- [INTEGER]
- NOT NULL
-
-key
-
- [VARCHAR(512)]
- NOT NULL
-
-map_index
-
- [INTEGER]
- NOT NULL
-
-task_id
-
- [VARCHAR(250)]
- NOT NULL
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-run_id
-
- [VARCHAR(250)]
- NOT NULL
-
-timestamp
-
- [TIMESTAMP]
- NOT NULL
-
-value
-
- [JSONB]
+
+xcom
+
+dag_run_id
+
+ [INTEGER]
+ NOT NULL
+
+key
+
+ [VARCHAR(512)]
+ NOT NULL
+
+map_index
+
+ [INTEGER]
+ NOT NULL
+
+task_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+run_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+timestamp
+
+ [TIMESTAMP]
+ NOT NULL
+
+value
+
+ [JSONB]
task_instance--xcom
-
-0..N
-1
+
+0..N
+1
task_instance--xcom
-
-0..N
-1
+
+0..N
+1
task_instance--xcom
-
-0..N
-1
+
+0..N
+1
task_instance--xcom
-
-0..N
-1
+
+0..N
+1
task_instance_note
-
-task_instance_note
-
-ti_id
-
- [UUID]
- NOT NULL
-
-content
-
- [VARCHAR(1000)]
-
-created_at
-
- [TIMESTAMP]
- NOT NULL
-
-updated_at
-
- [TIMESTAMP]
- NOT NULL
-
-user_id
-
- [VARCHAR(128)]
+
+task_instance_note
+
+ti_id
+
+ [UUID]
+ NOT NULL
+
+content
+
+ [VARCHAR(1000)]
+
+created_at
+
+ [TIMESTAMP]
+ NOT NULL
+
+updated_at
+
+ [TIMESTAMP]
+ NOT NULL
+
+user_id
+
+ [VARCHAR(128)]
task_instance--task_instance_note
-
-1
-1
+
+1
+1
task_instance_history
-
-task_instance_history
-
-try_id
-
- [UUID]
- NOT NULL
-
-context_carrier
-
- [JSONB]
-
-custom_operator_name
-
- [VARCHAR(1000)]
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-dag_version_id
-
- [UUID]
-
-duration
-
- [DOUBLE_PRECISION]
-
-end_date
-
- [TIMESTAMP]
-
-executor
-
- [VARCHAR(1000)]
-
-executor_config
-
- [BYTEA]
-
-external_executor_id
-
- [VARCHAR(250)]
-
-hostname
-
- [VARCHAR(1000)]
-
-map_index
-
- [INTEGER]
- NOT NULL
-
-max_tries
-
- [INTEGER]
-
-next_kwargs
-
- [JSONB]
-
-next_method
-
- [VARCHAR(1000)]
-
-operator
-
- [VARCHAR(1000)]
-
-pid
-
- [INTEGER]
-
-pool
-
- [VARCHAR(256)]
- NOT NULL
-
-pool_slots
-
- [INTEGER]
- NOT NULL
-
-priority_weight
-
- [INTEGER]
-
-queue
-
- [VARCHAR(256)]
-
-queued_by_job_id
-
- [INTEGER]
-
-queued_dttm
-
- [TIMESTAMP]
-
-rendered_map_index
-
- [VARCHAR(250)]
-
-run_id
-
- [VARCHAR(250)]
- NOT NULL
-
-scheduled_dttm
-
- [TIMESTAMP]
-
-span_status
-
- [VARCHAR(250)]
- NOT NULL
-
-start_date
-
- [TIMESTAMP]
-
-state
-
- [VARCHAR(20)]
-
-task_display_name
-
- [VARCHAR(2000)]
-
-task_id
-
- [VARCHAR(250)]
- NOT NULL
-
-task_instance_id
-
- [UUID]
- NOT NULL
-
-trigger_id
-
- [INTEGER]
-
-trigger_timeout
-
- [TIMESTAMP]
-
-try_number
-
- [INTEGER]
- NOT NULL
-
-unixname
-
- [VARCHAR(1000)]
-
-updated_at
-
- [TIMESTAMP]
+
+task_instance_history
+
+task_instance_id
+
+ [UUID]
+ NOT NULL
+
+context_carrier
+
+ [JSONB]
+
+custom_operator_name
+
+ [VARCHAR(1000)]
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+dag_version_id
+
+ [UUID]
+
+duration
+
+ [DOUBLE_PRECISION]
+
+end_date
+
+ [TIMESTAMP]
+
+executor
+
+ [VARCHAR(1000)]
+
+executor_config
+
+ [BYTEA]
+
+external_executor_id
+
+ [VARCHAR(250)]
+
+hostname
+
+ [VARCHAR(1000)]
+
+map_index
+
+ [INTEGER]
+ NOT NULL
+
+max_tries
+
+ [INTEGER]
+
+next_kwargs
+
+ [JSONB]
+
+next_method
+
+ [VARCHAR(1000)]
+
+operator
+
+ [VARCHAR(1000)]
+
+pid
+
+ [INTEGER]
+
+pool
+
+ [VARCHAR(256)]
+ NOT NULL
+
+pool_slots
+
+ [INTEGER]
+ NOT NULL
+
+priority_weight
+
+ [INTEGER]
+
+queue
+
+ [VARCHAR(256)]
+
+queued_by_job_id
+
+ [INTEGER]
+
+queued_dttm
+
+ [TIMESTAMP]
+
+rendered_map_index
+
+ [VARCHAR(250)]
+
+run_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+scheduled_dttm
+
+ [TIMESTAMP]
+
+span_status
+
+ [VARCHAR(250)]
+ NOT NULL
+
+start_date
+
+ [TIMESTAMP]
+
+state
+
+ [VARCHAR(20)]
+
+task_display_name
+
+ [VARCHAR(2000)]
+
+task_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+trigger_id
+
+ [INTEGER]
+
+trigger_timeout
+
+ [TIMESTAMP]
+
+try_number
+
+ [INTEGER]
+ NOT NULL
+
+unixname
+
+ [VARCHAR(1000)]
+
+updated_at
+
+ [TIMESTAMP]
task_instance--task_instance_history
-
-0..N
-1
+
+0..N
+1
task_instance--task_instance_history
-
-0..N
-1
+
+0..N
+1
task_instance--task_instance_history
-
-0..N
-1
+
+0..N
+1
task_instance--task_instance_history
-
-0..N
-1
+
+0..N
+1
@@ -2073,8 +2058,8 @@
dag_version--task_instance
-
-0..N
+
+0..N
{0,1}
diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst
index dbed28f94864e..2f35cf2deecf4 100644
--- a/airflow-core/docs/migrations-ref.rst
+++ b/airflow-core/docs/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
-| ``959e216a3abb`` (head) | ``0e9519b56710`` | ``3.0.0`` | Rename ``is_active`` to ``is_stale`` column in ``dag`` |
+| ``29ce7909c52b`` (head) | ``959e216a3abb`` | ``3.0.0`` | Change TI table to have unique UUID id/pk per attempt. |
++-------------------------+------------------+-------------------+--------------------------------------------------------------+
+| ``959e216a3abb`` | ``0e9519b56710`` | ``3.0.0`` | Rename ``is_active`` to ``is_stale`` column in ``dag`` |
| | | | table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``0e9519b56710`` | ``ec62e120484d`` | ``3.0.0`` | Rename run_type from 'dataset_triggered' to |
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 6611aeb56de46..ef7a6bf6d60e9 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -213,7 +213,7 @@ def ti_run(
session.query(
func.count(TaskReschedule.id) # or any other primary key column
)
- .filter(TaskReschedule.ti_id == ti_id_str, TaskReschedule.try_number == ti.try_number)
+ .filter(TaskReschedule.ti_id == ti_id_str)
.scalar()
or 0
)
@@ -309,13 +309,9 @@ def ti_update_state(
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
- from airflow.models.taskinstance import uuid7
- from airflow.models.taskinstancehistory import TaskInstanceHistory
-
ti = session.get(TI, ti_id_str)
- TaskInstanceHistory.record_ti(ti, session=session)
- ti.try_id = uuid7()
updated_state = ti_patch_payload.state
+ ti.prepare_db_for_next_try(session)
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
@@ -393,7 +389,6 @@ def ti_update_state(
session.add(
TaskReschedule(
task_instance.id,
- task_instance.try_number,
actual_start_date,
ti_patch_payload.end_date,
ti_patch_payload.reschedule_date,
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
index 0c023f2e2711f..d3e940f47a08f 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py
@@ -17,10 +17,9 @@
from __future__ import annotations
-from typing import Annotated
from uuid import UUID
-from fastapi import Query, status
+from fastapi import status
from sqlalchemy import select
from airflow.api_fastapi.common.db.common import SessionDep
@@ -37,18 +36,12 @@
@router.get("/{task_instance_id}/start_date")
-def get_start_date(
- task_instance_id: UUID, session: SessionDep, try_number: Annotated[int, Query()] = 1
-) -> UtcDateTime | None:
+def get_start_date(task_instance_id: UUID, session: SessionDep) -> UtcDateTime | None:
"""Get the first reschedule date if found, None if no records exist."""
start_date = session.scalar(
- select(TaskReschedule)
- .where(
- TaskReschedule.ti_id == str(task_instance_id),
- TaskReschedule.try_number >= try_number,
- )
+ select(TaskReschedule.start_date)
+ .where(TaskReschedule.ti_id == str(task_instance_id))
.order_by(TaskReschedule.id.asc())
- .with_only_columns(TaskReschedule.start_date)
.limit(1)
)
diff --git a/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py
new file mode 100644
index 0000000000000..e0af8511ca724
--- /dev/null
+++ b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py
@@ -0,0 +1,130 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Change TI table to have unique UUID id/pk per attempt.
+
+Revision ID: 29ce7909c52b
+Revises: 959e216a3abb
+Create Date: 2025-04-09 10:09:53.130924
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy_utils import UUIDType
+
+# revision identifiers, used by Alembic.
+revision = "29ce7909c52b"
+down_revision = "959e216a3abb"
+branch_labels = None
+depends_on = None
+airflow_version = "3.0.0"
+
+
+def _get_uuid_type(dialect_name: str) -> sa.types.TypeEngine:
+ if dialect_name == "sqlite":
+ return sa.String(36)
+ else:
+ return UUIDType(binary=False)
+
+
+def upgrade():
+ """Apply Change TI table to have unique UUID id/pk per attempt."""
+ conn = op.get_bind()
+ dialect_name = conn.dialect.name
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.drop_constraint("task_instance_try_id_uq", type_="unique")
+ batch_op.drop_column("try_id")
+
+ with op.batch_alter_table("task_instance_history", schema=None) as batch_op:
+ batch_op.create_index("idx_tih_dag_run", ["dag_id", "run_id"], unique=False)
+ batch_op.drop_column("task_instance_id")
+ batch_op.alter_column(
+ "try_id",
+ new_column_name="task_instance_id",
+ existing_type=_get_uuid_type(dialect_name),
+ existing_nullable=False,
+ )
+
+ with op.batch_alter_table("task_instance_note", schema=None) as batch_op:
+ batch_op.drop_constraint("task_instance_note_ti_fkey", type_="foreignkey")
+ batch_op.create_foreign_key(
+ "task_instance_note_ti_fkey",
+ "task_instance",
+ ["ti_id"],
+ ["id"],
+ onupdate="CASCADE",
+ ondelete="CASCADE",
+ )
+
+ # We decided not to migrate/correct the data for task_reschedule, as the time to do it wasn't
+ # worth it: 90%+ of the data in the table is not needed -- this table only has use when a Sensor task,
+ # with reschedule mode, is in the `running` or `up_for_reschedule` states.
+ #
+ # Going forward, Airflow will delete rows from this table when the TI is recorded in the history table
+ # (i.e. when it's cleared or retried) so the only case in which this data will be accessed and give an
+ # incorrect figure is when all of these hold true
+ #
+ # - a Sensor in reschedule mode
+ # - in running or up_for_reschedule states
+ # - On a try_number > 1
+ # - with a timeout set
+ #
+ # If all of these are true, then the total runtime will be mistakenly calculated from the start of the first
+ # try to now, rather than from the start of the current try. But this is such an unlikely set of circumstances
+ # that it's not worth the time cost of migrating it.
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.drop_column("try_number")
+
+
+def downgrade():
+ """Unapply Change TI table to have unique UUID id/pk per attempt."""
+ conn = op.get_bind()
+ dialect_name = conn.dialect.name
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column("try_number", sa.INTEGER(), autoincrement=False, nullable=False, default=1)
+ )
+
+ with op.batch_alter_table("task_instance_note", schema=None) as batch_op:
+ batch_op.drop_constraint("task_instance_note_ti_fkey", type_="foreignkey")
+ batch_op.create_foreign_key(
+ "task_instance_note_ti_fkey", "task_instance", ["ti_id"], ["id"], ondelete="CASCADE"
+ )
+
+ with op.batch_alter_table("task_instance_history", schema=None) as batch_op:
+ batch_op.alter_column(
+ "task_instance_id",
+ new_column_name="try_id",
+ existing_type=_get_uuid_type(dialect_name),
+ existing_nullable=False,
+ )
+ batch_op.drop_index("idx_tih_dag_run")
+ # This has to be in a separate batch, else on sqlite it throws `sqlalchemy.exc.CircularDependencyError`
+ # (and on non-sqlite, batching isn't "a thing"; it issues ALTER TABLE statements fine)
+ with op.batch_alter_table("task_instance_history", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column("task_instance_id", UUIDType(binary=False), autoincrement=False, nullable=False)
+ )
+
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("try_id", _get_uuid_type(dialect_name), nullable=False))
+ batch_op.create_unique_constraint("task_instance_try_id_uq", ["try_id"])
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 3f668e1618fd3..df4ecf6f56e65 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -49,7 +49,6 @@
not_,
or_,
text,
- tuple_,
update,
)
from sqlalchemy.dialects import postgresql
@@ -1831,7 +1830,7 @@ def schedule_tis(
"""
# Get list of TI IDs that do not need to executed, these are
# tasks using EmptyOperator and without on_execute_callback / on_success_callback
- dummy_ti_ids = []
+ empty_ti_ids = []
schedulable_ti_ids = []
for ti in schedulable_tis:
if TYPE_CHECKING:
@@ -1842,7 +1841,7 @@ def schedule_tis(
and not ti.task.on_success_callback
and not ti.task.outlets
):
- dummy_ti_ids.append((ti.task_id, ti.map_index))
+ empty_ti_ids.append(ti.id)
# check "start_trigger_args" to see whether the operator supports start execution from triggerer
# if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
# this task.
@@ -1857,9 +1856,9 @@ def schedule_tis(
ti.try_number += 1
ti.defer_task(exception=None, session=session)
else:
- schedulable_ti_ids.append((ti.task_id, ti.map_index))
+ schedulable_ti_ids.append(ti.id)
else:
- schedulable_ti_ids.append((ti.task_id, ti.map_index))
+ schedulable_ti_ids.append(ti.id)
count = 0
@@ -1867,14 +1866,10 @@ def schedule_tis(
schedulable_ti_ids_chunks = chunks(
schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
)
- for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
+ for id_chunk in schedulable_ti_ids_chunks:
count += session.execute(
update(TI)
- .where(
- TI.dag_id == self.dag_id,
- TI.run_id == self.run_id,
- tuple_(TI.task_id, TI.map_index).in_(schedulable_ti_ids_chunk),
- )
+ .where(TI.id.in_(id_chunk))
.values(
state=TaskInstanceState.SCHEDULED,
scheduled_dttm=timezone.utcnow(),
@@ -1890,16 +1885,12 @@ def schedule_tis(
).rowcount
# Tasks using EmptyOperator should not be executed, mark them as success
- if dummy_ti_ids:
- dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
- for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
+ if empty_ti_ids:
+ dummy_ti_ids_chunks = chunks(empty_ti_ids, max_tis_per_query or len(empty_ti_ids))
+ for id_chunk in dummy_ti_ids_chunks:
count += session.execute(
update(TI)
- .where(
- TI.dag_id == self.dag_id,
- TI.run_id == self.run_id,
- tuple_(TI.task_id, TI.map_index).in_(dummy_ti_ids_chunk),
- )
+ .where(TI.id.in_(id_chunk))
.values(
state=TaskInstanceState.SUCCESS,
start_date=timezone.utcnow(),
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 11a40181f7861..f15b8e9b918b9 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -447,12 +447,10 @@ def clear_task_instances(
# taskinstance uuids:
task_instance_ids: list[str] = []
dag_bag = DagBag(read_dags_from_db=True)
- from airflow.models.taskinstancehistory import TaskInstanceHistory
for ti in tis:
task_instance_ids.append(ti.id)
- TaskInstanceHistory.record_ti(ti, session)
- ti.try_id = uuid7()
+ ti.prepare_db_for_next_try(session)
if ti.state == TaskInstanceState.RUNNING:
# If a task is cleared when running, set its state to RESTARTING so that
# the task is terminated and becomes eligible for retry.
@@ -477,11 +475,6 @@ def clear_task_instances(
ti.clear_next_method_args()
session.merge(ti)
- if task_instance_ids:
- # Clear all reschedules related to the ti to clear
- delete_qry = TR.__table__.delete().where(TR.ti_id.in_(task_instance_ids))
- session.execute(delete_qry)
-
if dag_run_state is not False and tis:
from airflow.models.dagrun import DagRun # Avoid circular import
@@ -1436,7 +1429,6 @@ def _handle_reschedule(
session.add(
TaskReschedule(
ti.id,
- ti.try_number,
actual_start_date,
ti.end_date,
reschedule_exception.reschedule_date,
@@ -1486,7 +1478,6 @@ class TaskInstance(Base, LoggingMixin):
end_date = Column(UtcDateTime)
duration = Column(Float)
state = Column(String(20))
- try_id = Column(UUIDType(binary=False), default=uuid7, unique=True, nullable=False)
try_number = Column(Integer, default=0)
max_tries = Column(Integer, server_default=text("-1"))
hostname = Column(String(1000))
@@ -1945,33 +1936,48 @@ def refresh_from_db(
:param keep_local_changes: Force all attributes to the values from the database if False (the default),
or if True don't overwrite locally set attributes
"""
- source = TaskInstance.get_task_instance(
+ query = select(
+ # Select the columns, not the ORM object, to bypass any session/ORM caching layer
+ c
+ for c in TaskInstance.__table__.columns
+ ).filter_by(
dag_id=self.dag_id,
- task_id=self.task_id,
run_id=self.run_id,
+ task_id=self.task_id,
map_index=self.map_index,
- lock_for_update=lock_for_update,
- session=session,
)
- if source:
- from sqlalchemy.orm import attributes
- source_state = inspect(source)
- if source_state is None:
- raise RuntimeError(f"Unable to inspect SQLAlchemy state of {type(source)}: {source}")
+ if lock_for_update:
+ query = query.with_for_update()
+
+ source = session.execute(query).mappings().one_or_none()
+ if source:
target_state = inspect(self)
if target_state is None:
raise RuntimeError(f"Unable to inspect SQLAlchemy state of {type(self)}: {self}")
- for name, attr in source_state.attrs.items():
- if keep_local_changes and target_state.attrs[name].history.has_changes():
+
+ # To deal with `@hybrid_property` we need to get the names from `mapper.columns`
+ for attr_name, col in target_state.mapper.columns.items():
+ if keep_local_changes and target_state.attrs[attr_name].history.has_changes():
continue
- val = attr.loaded_value
+ set_committed_value(self, attr_name, source[col.name])
- if val is not attributes.NO_VALUE:
- set_committed_value(self, name, val)
+ # ID may have changed, update SQLAs state and object tracking
+ newkey = session.identity_key(type(self), (self.id,))
+
+ # Delete anything under the new key
+ if newkey != target_state.key:
+ old = session.identity_map.get(newkey)
+ if old is not self and old is not None:
+ session.expunge(old)
+ target_state.key = newkey
+
+ if target_state.attrs.dag_run.loaded_value is not NO_VALUE:
+ dr_key = session.identity_key(type(self.dag_run), (self.dag_run.id,))
+ if (dr := session.identity_map.get(dr_key)) is not None:
+ set_committed_value(self, "dag_run", dr)
- target_state.key = source_state.key
else:
self.state = None
@@ -2019,32 +2025,6 @@ def key(self) -> TaskInstanceKey:
"""Returns a tuple that identifies the task instance uniquely."""
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
- @staticmethod
- def _set_state(ti: TaskInstance, state, session: Session) -> bool:
- if not isinstance(ti, TaskInstance):
- ti = session.scalars(
- select(TaskInstance).where(
- TaskInstance.task_id == ti.task_id,
- TaskInstance.dag_id == ti.dag_id,
- TaskInstance.run_id == ti.run_id,
- TaskInstance.map_index == ti.map_index,
- )
- ).one()
-
- if ti.state == state:
- return False
-
- current_time = timezone.utcnow()
- ti.log.debug("Setting task state for %s to %s", ti, state)
- ti.state = state
- ti.start_date = ti.start_date or current_time
- if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
- ti.end_date = ti.end_date or current_time
- ti.duration = (ti.end_date - ti.start_date).total_seconds()
-
- session.merge(ti)
- return True
-
@provide_session
def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
"""
@@ -2054,7 +2034,21 @@ def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
:param session: SQLAlchemy ORM Session
:return: Was the state changed
"""
- return self._set_state(ti=self, state=state, session=session)
+ if self.state == state:
+ return False
+
+ current_time = timezone.utcnow()
+ self.log.debug("Setting task state for %s to %s", self, state)
+ if self not in session:
+ self.refresh_from_db(session)
+ self.state = state
+ self.start_date = self.start_date or current_time
+ if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY:
+ self.end_date = self.end_date or current_time
+ self.duration = (self.end_date - self.start_date).total_seconds()
+ session.merge(self)
+ session.flush()
+ return True
@property
def is_premature(self) -> bool:
@@ -2062,6 +2056,14 @@ def is_premature(self) -> bool:
# is the task still in the retry waiting period?
return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry()
+ def prepare_db_for_next_try(self, session: Session):
+ """Update the metadata with all the records needed to put this TI into the queued state for the next try."""
+ from airflow.models.taskinstancehistory import TaskInstanceHistory
+
+ TaskInstanceHistory.record_ti(self, session=session)
+ session.execute(delete(TaskReschedule).filter_by(ti_id=self.id))
+ self.id = uuid7()
+
@provide_session
def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
"""
@@ -2994,10 +2996,7 @@ def fetch_handle_failure_context(
# If the task instance is in the running state, it means it raised an exception and
# about to retry so we record the task instance history. For other states, the task
# instance was cleared and already recorded in the task instance history.
- from airflow.models.taskinstancehistory import TaskInstanceHistory
-
- TaskInstanceHistory.record_ti(ti, session=session)
- ti.try_id = uuid7()
+ ti.prepare_db_for_next_try(session)
ti.state = State.UP_FOR_RETRY
email_for_state = operator.attrgetter("email_on_retry")
@@ -3553,21 +3552,14 @@ def _get_inactive_asset_unique_keys(
def get_first_reschedule_date(self, context: Context) -> datetime | None:
"""Get the first reschedule date for the task instance."""
- # TODO: AIP-72: Remove this after `ti.run` is migrated to use Task SDK
- max_tries: int = self.max_tries or 0
-
if TYPE_CHECKING:
assert isinstance(self.task, BaseOperator)
- retries: int = self.task.retries or 0
- first_try_number = max_tries - retries + 1
-
with create_session() as session:
start_date = session.scalar(
select(TaskReschedule)
.where(
TaskReschedule.ti_id == str(self.id),
- TaskReschedule.try_number >= first_try_number,
)
.order_by(TaskReschedule.id.asc())
.with_only_columns(TaskReschedule.start_date)
@@ -3715,6 +3707,7 @@ class TaskInstanceNote(Base):
],
name="task_instance_note_ti_fkey",
ondelete="CASCADE",
+ onupdate="CASCADE",
),
)
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py
index 56b9e13d29c61..7e585424cb2f4 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -25,6 +25,7 @@
DateTime,
Float,
ForeignKeyConstraint,
+ Index,
Integer,
String,
UniqueConstraint,
@@ -32,7 +33,6 @@
select,
text,
)
-from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
@@ -62,11 +62,7 @@ class TaskInstanceHistory(Base):
"""
__tablename__ = "task_instance_history"
- try_id = Column(UUIDType(binary=False), nullable=False, primary_key=True)
- task_instance_id = Column(
- String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
- nullable=False,
- )
+ task_instance_id = Column(UUIDType(binary=False), nullable=False, primary_key=True)
task_id = Column(StringID(), nullable=False)
dag_id = Column(StringID(), nullable=False)
run_id = Column(StringID(), nullable=False)
@@ -150,6 +146,7 @@ def __init__(
"try_number",
name="task_instance_history_dtrt_uq",
),
+ Index("idx_tih_dag_run", dag_id, run_id),
)
@staticmethod
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index f82b95537ce83..e07d750ebdef7 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -55,7 +55,6 @@ class TaskReschedule(Base):
ForeignKey("task_instance.id", ondelete="CASCADE", name="task_reschedule_ti_fkey"),
nullable=False,
)
- try_number = Column(Integer, nullable=False)
start_date = Column(UtcDateTime, nullable=False)
end_date = Column(UtcDateTime, nullable=False)
duration = Column(Integer, nullable=False)
@@ -67,14 +66,12 @@ class TaskReschedule(Base):
def __init__(
self,
- task_instance_id: uuid.UUID,
- try_number: int,
+ ti_id: uuid.UUID,
start_date: datetime.datetime,
end_date: datetime.datetime,
reschedule_date: datetime.datetime,
) -> None:
- self.ti_id = task_instance_id
- self.try_number = try_number
+ self.ti_id = ti_id
self.start_date = start_date
self.end_date = end_date
self.reschedule_date = reschedule_date
@@ -85,7 +82,6 @@ def stmt_for_task_instance(
cls,
ti: TaskInstance,
*,
- try_number: int | None = None,
descending: bool = False,
) -> Select:
"""
@@ -93,14 +89,6 @@ def stmt_for_task_instance(
:param ti: the task instance to find task reschedules for
:param descending: If True then records are returned in descending order
- :param try_number: Look for TaskReschedule of the given try_number. Default is None which
- looks for the same try_number of the given task_instance.
:meta private:
"""
- if try_number is None:
- try_number = ti.try_number
- return (
- select(cls)
- .where(cls.ti_id == ti.id, cls.try_number == try_number)
- .order_by(desc(cls.id) if descending else asc(cls.id))
- )
+ return select(cls).where(cls.ti_id == ti.id).order_by(desc(cls.id) if descending else asc(cls.id))
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index a563bbebb0af9..e6104a7f44330 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -92,7 +92,7 @@ class MappedClassProtocol(Protocol):
"2.9.2": "686269002441",
"2.10.0": "22ed7efa9da2",
"2.10.3": "5f2621c13b39",
- "3.0.0": "959e216a3abb",
+ "3.0.0": "29ce7909c52b",
}
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
index 70b54c31dba25..690b0f7bd7642 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py
@@ -189,7 +189,13 @@ def time_freezer(self) -> Generator:
freezer.stop()
@pytest.fixture(autouse=True)
- def setup(self) -> None:
+ def setup(self):
+ clear_db_assets()
+ clear_db_runs()
+ clear_db_logs()
+
+ yield
+
clear_db_assets()
clear_db_runs()
clear_db_logs()
@@ -229,7 +235,7 @@ def create_asset_dag_run(self, session, num: int = 2):
class TestGetAssets(TestAssets):
def test_should_respond_200(self, test_client, session):
- self.create_assets()
+ asset1, asset2 = self.create_assets(session)
session.add(AssetModel("inactive", "inactive"))
session.commit()
@@ -243,7 +249,7 @@ def test_should_respond_200(self, test_client, session):
assert response_data == {
"assets": [
{
- "id": 1,
+ "id": asset1.id,
"name": "simple1",
"uri": "s3://bucket/key/1",
"group": "asset",
@@ -255,7 +261,7 @@ def test_should_respond_200(self, test_client, session):
"aliases": [],
},
{
- "id": 2,
+ "id": asset2.id,
"name": "simple2",
"uri": "s3://bucket/key/2",
"group": "asset",
@@ -271,10 +277,9 @@ def test_should_respond_200(self, test_client, session):
}
def test_should_show_inactive(self, test_client, session):
- self.create_assets()
+ asset1, asset2 = self.create_assets(session)
session.add(
- AssetModel(
- id=3,
+ asset3 := AssetModel(
name="simple3",
uri="s3://bucket/key/3",
group="asset",
@@ -295,7 +300,7 @@ def test_should_show_inactive(self, test_client, session):
assert response_data == {
"assets": [
{
- "id": 1,
+ "id": asset1.id,
"name": "simple1",
"uri": "s3://bucket/key/1",
"group": "asset",
@@ -307,7 +312,7 @@ def test_should_show_inactive(self, test_client, session):
"aliases": [],
},
{
- "id": 2,
+ "id": asset2.id,
"name": "simple2",
"uri": "s3://bucket/key/2",
"group": "asset",
@@ -319,7 +324,7 @@ def test_should_show_inactive(self, test_client, session):
"aliases": [],
},
{
- "id": 3,
+ "id": asset3.id,
"name": "simple3",
"uri": "s3://bucket/key/3",
"group": "asset",
@@ -624,11 +629,12 @@ def test_should_respect_page_size_limit_default(self, test_client):
class TestGetAssetEvents(TestAssets):
def test_should_respond_200(self, test_client, session):
- self.create_assets()
- self.create_assets_events()
- self.create_dag_run()
- self.create_asset_dag_run()
+ asset1, asset2 = self.create_assets(session)
+ self.create_assets_events(session)
+ self.create_dag_run(session)
+ self.create_asset_dag_run(session)
+ session.commit()
assets = session.query(AssetEvent).all()
assert len(assets) == 2
response = test_client.get("/assets/events")
assert response.status_code == 200
@@ -944,6 +950,8 @@ def test_should_respond_404(self, test_client):
class TestQueuedEventEndpoint(TestAssets):
def _create_asset_dag_run_queues(self, dag_id, asset_id, session):
+ session.query(AssetDagRunQueue).delete()
+ session.flush()
adrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id)
session.add(adrq)
session.commit()
@@ -955,9 +963,8 @@ class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint):
def test_should_respond_200(self, test_client, session, create_dummy_dag):
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
- self.create_assets(session=session, num=1)
- asset_id = 1
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
+ (asset,) = self.create_assets(session=session, num=1)
+ self._create_asset_dag_run_queues(dag_id, asset.id, session)
response = test_client.get(
f"/dags/{dag_id}/assets/queuedEvents",
@@ -967,7 +974,7 @@ def test_should_respond_200(self, test_client, session, create_dummy_dag):
assert response.json() == {
"queued_events": [
{
- "asset_id": 1,
+ "asset_id": asset.id,
"dag_id": "dag",
"created_at": from_datetime_to_zulu_without_ms(DEFAULT_DATE),
}
@@ -1050,13 +1057,13 @@ def test_should_respond_404_valid_dag_no_adrq(self, test_client, session, create
class TestPostAssetEvents(TestAssets):
@pytest.mark.usefixtures("time_freezer")
def test_should_respond_200(self, test_client, session):
- self.create_assets(session)
- event_payload = {"asset_id": 1, "extra": {"foo": "bar"}}
+ (asset,) = self.create_assets(session, num=1)
+ event_payload = {"asset_id": asset.id, "extra": {"foo": "bar"}}
response = test_client.post("/assets/events", json=event_payload)
assert response.status_code == 200
assert response.json() == {
"id": mock.ANY,
- "asset_id": 1,
+ "asset_id": asset.id,
"uri": "s3://bucket/key/1",
"group": "asset",
"name": "simple1",
@@ -1088,13 +1095,13 @@ def test_invalid_attr_not_allowed(self, test_client, session):
@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
- self.create_assets(session)
- event_payload = {"asset_id": 1, "extra": {"password": "bar"}}
+ (asset,) = self.create_assets(session, num=1)
+ event_payload = {"asset_id": asset.id, "extra": {"password": "bar"}}
response = test_client.post("/assets/events", json=event_payload)
assert response.status_code == 200
assert response.json() == {
"id": mock.ANY,
- "asset_id": 1,
+ "asset_id": asset.id,
"uri": "s3://bucket/key/1",
"group": "asset",
"name": "simple1",
@@ -1118,7 +1125,9 @@ class TestPostAssetMaterialize(TestAssets):
@pytest.fixture(autouse=True)
def create_dags(self, setup, dag_maker, session):
# Depend on 'setup' so it runs first. Otherwise it deletes what we create here.
- assets = {am.id: am.to_public() for am in self.create_assets(session=session, num=3)}
+ assets = {
+ i: am.to_public() for i, am in enumerate(self.create_assets(session=session, num=3), start=1)
+ }
with dag_maker(self.DAG_ASSET1_ID, schedule=None, session=session):
EmptyOperator(task_id="task", outlets=assets[1])
with dag_maker(self.DAG_ASSET2_ID_A, schedule=None, session=session):
@@ -1176,16 +1185,15 @@ class TestGetAssetQueuedEvents(TestQueuedEventEndpoint):
def test_should_respond_200(self, test_client, session, create_dummy_dag):
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
- self.create_assets(session=session, num=1)
- asset_id = 1
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
+ (asset,) = self.create_assets(session=session, num=1)
+ self._create_asset_dag_run_queues(dag_id, asset.id, session)
- response = test_client.get(f"/assets/{asset_id}/queuedEvents")
+ response = test_client.get(f"/assets/{asset.id}/queuedEvents")
assert response.status_code == 200
assert response.json() == {
"queued_events": [
{
- "asset_id": asset_id,
+ "asset_id": asset.id,
"dag_id": "dag",
"created_at": from_datetime_to_zulu_without_ms(DEFAULT_DATE),
}
@@ -1212,14 +1220,13 @@ class TestDeleteAssetQueuedEvents(TestQueuedEventEndpoint):
def test_should_respond_204(self, test_client, session, create_dummy_dag):
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
- self.create_assets(session=session, num=1)
- asset_id = 1
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
+ (asset,) = self.create_assets(session=session, num=1)
+ self._create_asset_dag_run_queues(dag_id, asset.id, session)
- assert session.get(AssetDagRunQueue, (asset_id, dag_id)) is not None
- response = test_client.delete(f"/assets/{asset_id}/queuedEvents")
+ assert session.get(AssetDagRunQueue, (asset.id, dag_id)) is not None
+ response = test_client.delete(f"/assets/{asset.id}/queuedEvents")
assert response.status_code == 204
- assert session.get(AssetDagRunQueue, (asset_id, dag_id)) is None
+ assert session.get(AssetDagRunQueue, (asset.id, dag_id)) is None
check_last_log(session, dag_id=None, event="delete_asset_queued_events", logical_date=None)
def test_should_respond_401(self, unauthenticated_test_client):
@@ -1240,15 +1247,14 @@ class TestDeleteDagAssetQueuedEvent(TestQueuedEventEndpoint):
def test_delete_should_respond_204(self, test_client, session, create_dummy_dag):
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
- self.create_assets(session=session, num=1)
- asset_id = 1
+ (asset,) = self.create_assets(session=session, num=1)
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
+ self._create_asset_dag_run_queues(dag_id, asset.id, session)
adrq = session.query(AssetDagRunQueue).all()
assert len(adrq) == 1
response = test_client.delete(
- f"/dags/{dag_id}/assets/{asset_id}/queuedEvents",
+ f"/dags/{dag_id}/assets/{asset.id}/queuedEvents",
)
assert response.status_code == 204
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
index 11f3846cacd78..b391fc66492b7 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
@@ -25,6 +25,7 @@
import pytest
from itsdangerous.url_safe import URLSafeSerializer
+from uuid6 import uuid7
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
from airflow.models.dag import DAG
@@ -72,6 +73,18 @@ def add_one(x: int):
self.app.state.dag_bag.bag_dag(dag)
+ for ti in dr.task_instances:
+ ti.try_number = 1
+ ti.hostname = "localhost"
+ session.merge(ti)
+ dag.clear()
+ for ti in dr.task_instances:
+ ti.try_number = 2
+ ti.id = str(uuid7())
+ ti.hostname = "localhost"
+ session.merge(ti)
+ session.flush()
+
# Add dummy dag for checking picking correct log with same task_id and different dag_id case.
with dag_maker(
f"{self.DAG_ID}_copy", start_date=timezone.parse(self.default_time), session=session
@@ -85,27 +98,21 @@ def add_one(x: int):
)
self.app.state.dag_bag.bag_dag(dummy_dag)
- for ti in dr.task_instances:
- ti.try_number = 1
- ti.hostname = "localhost"
- session.merge(ti)
for ti in dr2.task_instances:
ti.try_number = 1
ti.hostname = "localhost"
session.merge(ti)
- session.flush()
- dag.clear()
dummy_dag.clear()
- for ti in dr.task_instances:
- ti.try_number = 2
- ti.hostname = "localhost"
- session.merge(ti)
for ti in dr2.task_instances:
ti.try_number = 2
+ ti.id = str(uuid7())
ti.hostname = "localhost"
session.merge(ti)
session.flush()
+
@pytest.fixture
def configure_loggers(self, tmp_path, create_log_template):
self.log_dir = tmp_path
@@ -165,6 +172,7 @@ def test_should_respond_200_json(self, try_number):
)
expected_filename = f"{self.log_dir}/dag_id={self.DAG_ID}/run_id={self.RUN_ID}/task_id={self.TASK_ID}/attempt={try_number}.log"
log_content = "Log for testing." if try_number == 1 else "Log for testing 2."
+ assert response.status_code == 200, response.json()
resp_contnt = response.json()["content"]
assert expected_filename in resp_contnt[0]["sources"]
assert log_content in resp_contnt[2]["event"]
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 4678e6c30eaf1..7b7b3e531624d 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -96,7 +96,7 @@ def create_task_instances(
dag_id: str = "example_python_operator",
update_extras: bool = True,
task_instances=None,
- dag_run_state=State.RUNNING,
+ dag_run_state=DagRunState.RUNNING,
with_ti_history=False,
):
"""Method to create task instances using kwargs and default arguments"""
@@ -133,6 +133,7 @@ def create_task_instances(
state=dag_run_state,
)
session.add(dr)
+ session.flush()
ti = TaskInstance(task=tasks[i], **self.ti_init)
session.add(ti)
ti.dag_run = dr
@@ -142,18 +143,20 @@ def create_task_instances(
setattr(ti, key, value)
tis.append(ti)
- session.commit()
+ session.flush()
+
if with_ti_history:
for ti in tis:
ti.try_number = 1
session.merge(ti)
- session.commit()
- dag.clear()
+ session.flush()
+ dag.clear(session=session)
for ti in tis:
ti.try_number = 2
ti.queue = "default_queue"
session.merge(ti)
- session.commit()
+ session.flush()
+ session.commit()
return tis
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 456c3f9fa2a48..c933c0091d53f 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -707,7 +707,7 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].task_instance.dag_id == "dag"
assert trs[0].task_instance.task_id == "test_ti_update_state_to_reschedule"
assert trs[0].task_instance.run_id == "test"
- assert trs[0].try_number == 0
+ assert trs[0].ti_id == tis[0].id
assert trs[0].start_date == instant
assert trs[0].end_date == DEFAULT_END_DATE
assert trs[0].reschedule_date == timezone.parse("2024-10-31T11:03:00+00:00")
@@ -763,20 +763,21 @@ def test_ti_update_state_handle_retry(self, client, session, create_task_instanc
assert response.status_code == 204
assert response.text == ""
- session.expire_all()
-
- ti = session.get(TaskInstance, ti.id)
+ ti = session.scalar(
+ select(TaskInstance).filter_by(task_id=ti.task_id, run_id=ti.run_id, dag_id=ti.dag_id)
+ )
assert ti.state == State.UP_FOR_RETRY
assert ti.next_method is None
assert ti.next_kwargs is None
tih = (
session.query(TaskInstanceHistory)
- .where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.task_instance_id == ti.id)
+ .where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.run_id == ti.run_id)
.one()
)
- assert tih.try_id
- assert tih.try_id != ti.try_id
+ assert tih.task_instance_id
+ assert tih.task_instance_id != ti.id
def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
# we just want to fail in this test, no need to retry
@@ -1204,8 +1205,7 @@ def test_get_start_date(self, client, session, create_task_instance):
session=session,
)
tr = TaskReschedule(
- task_instance_id=ti.id,
- try_number=1,
+ ti_id=ti.id,
start_date=timezone.datetime(2024, 1, 1),
end_date=timezone.datetime(2024, 1, 1, 1),
reschedule_date=timezone.datetime(2024, 1, 1, 2),
@@ -1222,37 +1222,6 @@ def test_get_start_date_not_found(self, client):
response = client.get(f"/execution/task-reschedules/{ti_id}/start_date")
assert response.json() is None
- def test_get_start_date_with_try_number(self, client, session, create_task_instance):
- # Create multiple reschedules
- dates = [
- timezone.datetime(2024, 1, 1),
- timezone.datetime(2024, 1, 2),
- timezone.datetime(2024, 1, 3),
- ]
-
- ti = create_task_instance(
- task_id="test_get_start_date_with_try_number",
- state=State.RUNNING,
- start_date=timezone.datetime(2024, 1, 1),
- session=session,
- )
-
- for i, date in enumerate(dates, 1):
- tr = TaskReschedule(
- task_instance_id=ti.id,
- try_number=i,
- start_date=date,
- end_date=date.replace(hour=1),
- reschedule_date=date.replace(hour=2),
- )
- session.add(tr)
- session.commit()
-
- # Test getting start date for try_number 2
- response = client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
- assert response.status_code == 200
- assert response.json() == "2024-01-02T00:00:00Z"
-
class TestGetCount:
def setup_method(self):
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 7e98c1683fad5..b1119628e22a7 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -1711,19 +1711,21 @@ def _create_dagruns():
# First pass we'll grab 6 of the 8 tasks (limited by max_tis_per_query)
res = self.job_runner._critical_section_enqueue_task_instances(session)
assert res == 6
+ session.flush()
for ti in tis1[:3] + tis2[:3]:
- ti.refresh_from_db()
+ ti.refresh_from_db(session)
assert ti.state == TaskInstanceState.QUEUED
for ti in tis1[3:] + tis2[3:]:
- ti.refresh_from_db()
+ ti.refresh_from_db(session)
assert ti.state == TaskInstanceState.SCHEDULED
# The remaining TIs are queued
res = self.job_runner._critical_section_enqueue_task_instances(session)
assert res == 2
+ session.flush()
for ti in tis1 + tis2:
- ti.refresh_from_db()
+ ti.refresh_from_db(session)
assert ti.state == State.QUEUED
@pytest.mark.parametrize(
@@ -2555,8 +2557,9 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
session = settings.Session()
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("dummy")
+ ti = dr.get_task_instance("dummy", session)
ti.set_state(state, session)
+ session.flush()
with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
self.job_runner._do_scheduling(session)
@@ -2579,7 +2582,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
@pytest.mark.parametrize(
"state, expected_callback_msg", [(State.SUCCESS, "success"), (State.FAILED, "task_failure")]
)
- def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_maker):
+ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_maker, session):
"""
Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor.
"""
@@ -2587,6 +2590,7 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak
dag_id="test_dagrun_callbacks_are_called",
on_success_callback=lambda x: print("success"),
on_failure_callback=lambda x: print("failed"),
+ session=session,
):
EmptyOperator(task_id="dummy")
@@ -2598,10 +2602,9 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak
self.job_runner.dagbag = dag_maker.dagbag
- session = settings.Session()
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("dummy")
+ ti = dr.get_task_instance("dummy", session)
ti.set_state(state, session)
with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
@@ -2611,8 +2614,8 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak
dag_listener.success = []
dag_listener.failure = []
+
session.rollback()
- session.close()
def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, session):
with dag_maker(
@@ -2700,12 +2703,13 @@ def mock_send_dag_callbacks_to_processor(*args, **kwargs):
session.close()
@pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED])
- def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state, dag_maker):
+ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state, dag_maker, session):
"""
Test if no on_*_callback are defined on DAG, Callbacks not registered and sent to DAG Processor
"""
with dag_maker(
dag_id="test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined",
+ session=session,
):
BashOperator(task_id="test_task", bash_command="echo hi")
@@ -2714,9 +2718,8 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta
self.job_runner._send_dag_callbacks_to_processor = mock.Mock()
- session = settings.Session()
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("test_task")
+ ti = dr.get_task_instance("test_task", session)
ti.set_state(state, session)
with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
@@ -2729,7 +2732,6 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta
assert call_args[1] is None
session.rollback()
- session.close()
@pytest.mark.parametrize("state, msg", [[State.SUCCESS, "success"], [State.FAILED, "task_failure"]])
def test_dagrun_callbacks_are_added_when_callbacks_are_defined(self, state, msg, dag_maker):
@@ -3483,10 +3485,10 @@ def run_with_error(ti, ignore_ti_state=False):
assert ti.state == State.UP_FOR_RETRY
assert ti.try_number == 1
- with create_session() as session:
- ti.refresh_from_db(lock_for_update=True, session=session)
- ti.state = State.SCHEDULED
- session.merge(ti)
+ ti.refresh_from_db(lock_for_update=True, session=session)
+ ti.state = State.SCHEDULED
+ session.merge(ti)
+ session.commit()
# To verify that task does get re-queued.
executor.do_update = True
@@ -6223,12 +6225,10 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session):
run_type=DagRunType.SCHEDULED,
)
scheduled_run.last_scheduling_decision = datetime.datetime.now(timezone.utc) - timedelta(minutes=1)
- ti = scheduled_run.get_task_instances()[0]
+ ti = scheduled_run.get_task_instances(session=session)[0]
ti.set_state(TaskInstanceState.RUNNING)
- dm = DagModel.get_dagmodel(dag.dag_id)
+ dm = DagModel.get_dagmodel(dag.dag_id, session)
dm.is_paused = True
- session.merge(dm)
- session.merge(ti)
session.flush()
assert scheduled_run.state == State.RUNNING
@@ -6252,7 +6252,8 @@ def test_update_dagrun_state_for_paused_dag(self, dag_maker, session):
assert prior_last_scheduling_decision == scheduled_run.last_scheduling_decision
# Once the TI is in a terminal state though, DagRun goes to success
- ti.set_state(TaskInstanceState.SUCCESS)
+ ti.set_state(TaskInstanceState.SUCCESS, session=session)
+
self.job_runner._update_dag_run_state_for_paused_dags(session=session)
(scheduled_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.SCHEDULED, session=session)
assert scheduled_run.state == State.SUCCESS
@@ -6265,12 +6266,10 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se
# Backfill run
backfill_run = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB)
backfill_run.last_scheduling_decision = datetime.datetime.now(timezone.utc) - timedelta(minutes=1)
- ti = backfill_run.get_task_instances()[0]
- ti.set_state(TaskInstanceState.SUCCESS)
- dm = DagModel.get_dagmodel(dag.dag_id)
+ ti = backfill_run.get_task_instances(session=session)[0]
+ ti.set_state(TaskInstanceState.SUCCESS, session=session)
+ dm = DagModel.get_dagmodel(dag.dag_id, session=session)
dm.is_paused = True
- session.merge(dm)
- session.merge(ti)
session.flush()
assert backfill_run.state == State.RUNNING
@@ -6278,7 +6277,7 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se
scheduler_job = Job(executor=self.null_exec)
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- self.job_runner._update_dag_run_state_for_paused_dags()
+ self.job_runner._update_dag_run_state_for_paused_dags(session=session)
session.flush()
(backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session)
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index f6626c3d7816b..9bcae8cceba51 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -76,8 +76,8 @@ def test_clear_task_instances(self, dag_maker):
# do the incrementing of try_number ordinarily handled by scheduler
ti0.try_number += 1
ti1.try_number += 1
- session.merge(ti0)
- session.merge(ti1)
+ ti0 = session.merge(ti0)
+ ti1 = session.merge(ti1)
session.commit()
# we use order_by(task_id) here because for the test DAG structure of ours
@@ -87,8 +87,9 @@ def test_clear_task_instances(self, dag_maker):
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
- ti0.refresh_from_db()
- ti1.refresh_from_db()
+ ti0.refresh_from_db(session)
+ ti1.refresh_from_db(session)
+
# Next try to run will be try 2
assert ti0.state is None
assert ti0.try_number == 1
@@ -530,11 +531,7 @@ def test_clear_task_instances_with_task_reschedule(self, dag_maker):
with create_session() as session:
def count_task_reschedule(ti):
- return (
- session.query(TaskReschedule)
- .filter(TaskReschedule.ti_id == ti.id, TaskReschedule.try_number == 1)
- .count()
- )
+ return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count()
assert count_task_reschedule(ti0) == 1
assert count_task_reschedule(ti1) == 1
@@ -626,12 +623,14 @@ def test_dag_clear(self, dag_maker):
assert ti0.try_number == 1
dag.clear()
ti0.refresh_from_db()
+ ti1.refresh_from_db()
assert ti0.try_number == 1
assert ti0.state == State.NONE
assert ti0.max_tries == 1
assert ti1.max_tries == 2
- session.get(TaskInstance, ti1.id).try_number += 1
+ session.add(ti1)
+ ti1.try_number += 1
session.commit()
ti1.run()
@@ -747,29 +746,37 @@ def test_operator_clear(self, dag_maker, session):
run_type=DagRunType.SCHEDULED,
)
- ti1, ti2 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
+ ti1, ti2 = sorted(dr.get_task_instances(session=session), key=lambda ti: ti.task_id)
ti1.task = op1
ti2.task = op2
session.get(TaskInstance, ti2.id).try_number += 1
session.commit()
- ti2.run()
+ ti2.run(session=session)
# Dependency not met
assert ti2.try_number == 1
assert ti2.max_tries == 1
- op2.clear(upstream=True)
+ op2.clear(upstream=True, session=session)
+ ti1.refresh_from_db(session)
+ ti2.refresh_from_db(session)
# max tries will be set to retries + curr try number == 1 + 1 == 2
- assert session.get(TaskInstance, ti2.id).max_tries == 2
+ assert ti2.max_tries == 2
- session.get(TaskInstance, ti1.id).try_number += 1
+ ti1.try_number += 1
+ session.merge(ti1)
session.commit()
- ti1.run()
+
+ ti1.run(session=session)
+ ti1.refresh_from_db(session)
+ ti2.refresh_from_db(session)
assert ti1.try_number == 1
- session.get(TaskInstance, ti2.id).try_number += 1
- session.commit()
- ti2.run(ignore_ti_state=True)
+ ti2.try_number += 1
+ session.add(ti2)
+ session.flush()
+ ti2.run(ignore_ti_state=True, session=session)
+ ti2.refresh_from_db(session)
# max_tries is 0 because there is no task instance in db for ti1
# so clear won't change the max_tries.
assert ti1.max_tries == 0
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index dacfcf7a75a22..35f4ec7964850 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -2403,8 +2403,7 @@ def printx(x):
ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
tr = TaskReschedule(
- task_instance_id=ti.id,
- try_number=ti.try_number,
+ ti_id=ti.id,
start_date=timezone.datetime(2017, 1, 1),
end_date=timezone.datetime(2017, 1, 2),
reschedule_date=timezone.datetime(2017, 1, 1),
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 1fc9de2a3b600..acbcb1939fd7a 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -724,11 +724,14 @@ def run_with_error(ti):
# clearing it first
dag.clear()
- session.get(TaskInstance, ti.id).try_number += 1
+ ti.refresh_from_db(session)
+ ti.try_number += 1
+ session.add(ti)
session.commit()
# third run -- up for retry
run_with_error(ti)
+ ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY
assert ti.try_number == 3
@@ -901,7 +904,7 @@ def run_ti_and_assert(
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)
done, fail = False, True
- run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 1)
+ run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 0)
# scheduler would create a new try here
with create_session() as session:
@@ -1005,7 +1008,7 @@ def run_ti_and_assert(
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)
done, fail = False, True
- run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 1)
+ run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 1, 0)
with create_session() as session:
session.get(TaskInstance, ti.id).try_number += 1
@@ -1041,7 +1044,6 @@ def func():
).expand(poke_interval=[0])
ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
ti.task = task
- assert ti.try_number == 0
def run_ti_and_assert(
run_date,
@@ -3537,8 +3539,7 @@ def test_handle_failure_task_undefined(self, create_task_instance):
del ti.task
ti.handle_failure("test ti.task undefined")
- @provide_session
- def test_handle_failure_fail_fast(self, create_dummy_dag, session=None):
+ def test_handle_failure_fail_fast(self, create_dummy_dag, session):
start_date = timezone.datetime(2016, 6, 1)
clear_db_runs()
@@ -3588,7 +3589,7 @@ def test_handle_failure_fail_fast(self, create_dummy_dag, session=None):
ti_ff = TI(task=fail_task, run_id=dr.run_id)
ti_ff.state = State.FAILED
session.add(ti_ff)
- session.flush()
+ session.commit()
ti_ff.handle_failure("test retry handling")
assert ti1.state == State.SUCCESS
@@ -3880,7 +3881,6 @@ def test_refresh_from_db(self, create_task_instance):
"end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234),
"duration": 1.234,
"state": State.SUCCESS,
- "try_id": mock.ANY,
"try_number": 1,
"max_tries": 1,
"hostname": "some_unique_hostname",
@@ -4023,19 +4023,18 @@ def test_task_instance_history_is_created_when_ti_goes_for_retry(self, dag_maker
dr = dag_maker.create_dagrun()
ti = dr.task_instances[0]
ti.task = task
- try_id = ti.try_id
+ try_id = ti.id
with pytest.raises(AirflowException):
ti.run()
ti = session.query(TaskInstance).one()
- # the ti.try_id should be different from the previous one
- assert ti.try_id != try_id
+ # the ti.id should be different from the previous one
+ assert ti.id != try_id
assert ti.state == State.UP_FOR_RETRY
assert session.query(TaskInstance).count() == 1
tih = session.query(TaskInstanceHistory).all()
assert len(tih) == 1
- # the new try_id should be different from what's recorded in tih
- assert tih[0].try_id == try_id
- assert tih[0].try_id != ti.try_id
+ # the history row records the previous ti id, not the new one
+ assert str(tih[0].task_instance_id) == try_id
@pytest.mark.skip(
reason="This test has some issues that were surfaced when dag_maker started allowing multiple serdag versions. Issue #48539 will track fixing this."
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py
index c56d0b84b3553..7b4bd22aa0436 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -88,8 +88,7 @@ def _create_task_reschedule(self, ti, minutes: int | list[int]):
dt = ti.logical_date + timedelta(minutes=minutes_timedelta)
trs.append(
TaskReschedule(
- task_instance_id=ti.id,
- try_number=ti.try_number,
+ ti_id=ti.id,
start_date=dt,
end_date=dt,
reschedule_date=dt,
diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py
index 83197298733ce..4261ae80cd0ce 100644
--- a/devel-common/src/tests_common/test_utils/mock_executor.py
+++ b/devel-common/src/tests_common/test_utils/mock_executor.py
@@ -78,6 +78,7 @@ def sort_by(item):
state = self.mock_task_results[key]
ti.set_state(state, session=session)
self.change_state(key, state)
+ session.flush()
def terminate(self):
pass
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index 20636cd07059f..5d0c99dc354bb 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -480,6 +480,7 @@ def f():
branch_ti.set_state(TaskInstanceState.SUCCESS, session=session)
dr.task_instance_scheduling_decisions(session=session)
branch_2_ti = dr.get_task_instance(task_id="branch_2", session=session)
+ branch_2_ti.task = self.branch_2
assert branch_2_ti.state == TaskInstanceState.SKIPPED
branch_2_ti.set_state(None)
branch_2_ti.run()
@@ -761,6 +762,7 @@ def test_clear_skipped_downstream_task(self):
sc_ti.set_state(TaskInstanceState.SUCCESS, session=session)
dr.task_instance_scheduling_decisions(session=session)
op1_ti = dr.get_task_instance(task_id="op1", session=session)
+ op1_ti.task = self.op1
assert op1_ti.state == TaskInstanceState.SKIPPED
op1_ti.set_state(None)
op1_ti.run()
@@ -1709,6 +1711,7 @@ def f():
branch_2_ti = dr.get_task_instance(task_id="branch_2", session=session)
# FIXME
if self.opcls != BranchExternalPythonOperator:
+ branch_2_ti.task = self.branch_2
assert branch_2_ti.state == TaskInstanceState.SKIPPED
branch_2_ti.set_state(None, session=session)
branch_2_ti.run()