Skip to content

Commit c9585c5

Browse files
committed
Remove select_columns option in TaskInstance.get_task_instance
Fundamentally, what's going on here is that we need a TaskInstance object instead of a Row object when sending over the wire in an RPC call. But the full story on this one is actually somewhat complicated. It was back in 2.2.0, in #25312, that we converted to querying with the column attrs instead of the TI object (#28900 only refactored this logic into a function). The reason was to avoid locking the dag_run table, since TI had newly gained a dag_run relationship attr. Now this causes a problem with AIP-44, because the RPC API does not know how to serialize a Row object. This PR switches back to querying a TaskInstance object, but avoids locking dag_run by using the lazyload option. Meanwhile, since try_number is a horrible attribute (which gives you a different answer depending on the state), we have to switch _refresh_from_db back to reading the underlying private attr (_try_number) instead of the public accessor.
1 parent 0723a8f commit c9585c5

2 files changed

Lines changed: 24 additions & 13 deletions

File tree

airflow/models/taskinstance.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
from sqlalchemy.ext.associationproxy import association_proxy
6262
from sqlalchemy.ext.mutable import MutableDict
63-
from sqlalchemy.orm import reconstructor, relationship
63+
from sqlalchemy.orm import lazyload, reconstructor, relationship
6464
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
6565
from sqlalchemy.sql.expression import case, select
6666

@@ -521,7 +521,6 @@ def _refresh_from_db(
521521
task_id=task_instance.task_id,
522522
run_id=task_instance.run_id,
523523
map_index=task_instance.map_index,
524-
select_columns=True,
525524
lock_for_update=lock_for_update,
526525
session=session,
527526
)
@@ -532,8 +531,7 @@ def _refresh_from_db(
532531
task_instance.end_date = ti.end_date
533532
task_instance.duration = ti.duration
534533
task_instance.state = ti.state
535-
# Since we selected columns, not the object, this is the raw value
536-
task_instance.try_number = ti.try_number
534+
task_instance.try_number = ti._try_number
537535
task_instance.max_tries = ti.max_tries
538536
task_instance.hostname = ti.hostname
539537
task_instance.unixname = ti.unixname
@@ -911,7 +909,7 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
911909
912910
:meta private:
913911
"""
914-
if task_instance.state == TaskInstanceState.RUNNING.RUNNING:
912+
if task_instance.state == TaskInstanceState.RUNNING:
915913
return task_instance._try_number
916914
return task_instance._try_number + 1
917915

@@ -1792,18 +1790,18 @@ def get_task_instance(
17921790
run_id: str,
17931791
task_id: str,
17941792
map_index: int,
1795-
select_columns: bool = False,
17961793
lock_for_update: bool = False,
17971794
session: Session = NEW_SESSION,
17981795
) -> TaskInstance | TaskInstancePydantic | None:
17991796
query = (
1800-
session.query(*TaskInstance.__table__.columns) if select_columns else session.query(TaskInstance)
1801-
)
1802-
query = query.filter_by(
1803-
dag_id=dag_id,
1804-
run_id=run_id,
1805-
task_id=task_id,
1806-
map_index=map_index,
1797+
session.query(TaskInstance)
1798+
.options(lazyload("dag_run")) # lazy load dag run to avoid locking it
1799+
.filter_by(
1800+
dag_id=dag_id,
1801+
run_id=run_id,
1802+
task_id=task_id,
1803+
map_index=map_index,
1804+
)
18071805
)
18081806

18091807
if lock_for_update:

tests/models/test_taskinstance.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4561,3 +4561,16 @@ def test_taskinstance_with_note(create_task_instance, session):
45614561

45624562
assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
45634563
assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
4564+
4565+
4566+
def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
4567+
with dag_maker():
4568+
BashOperator(task_id="hello", bash_command="hi")
4569+
dag_maker.create_dagrun(state="success")
4570+
ti = session.scalar(select(TaskInstance))
4571+
assert ti.task_id == "hello" # just to confirm...
4572+
assert ti.try_number == 1 # starts out as 1
4573+
ti.refresh_from_db()
4574+
assert ti.try_number == 1 # stays 1
4575+
ti.refresh_from_db()
4576+
assert ti.try_number == 1 # stays 1

0 commit comments

Comments
 (0)