Skip to content
7 changes: 4 additions & 3 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,19 @@ def _parse_file_entrypoint():

import structlog

from airflow.sdk.execution_time import comms, task_runner
from airflow.sdk.execution_time.comms import CommsDecoder
from airflow.sdk.execution_time.task_runner import SupervisorComms

# Parse DAG file, send JSON back up!
comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager](
comms_decoder = CommsDecoder[ToDagProcessor, ToManager](
body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
)

msg = comms_decoder._get_response()
if not isinstance(msg, DagFileParseRequest):
raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}")

task_runner.SUPERVISOR_COMMS = comms_decoder
SupervisorComms().set_comms(comms_decoder)
log = structlog.get_logger(logger_name="task")

result = _parse_file(msg, log)
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,10 +935,10 @@ async def init_comms(self):
"""
Set up the communications pipe between this process and the supervisor.

This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that will
This also sets up the supervisor-comms so that TaskSDK code can work as expected too (but that will
need to be wrapped in an ``sync_to_async()`` call)
"""
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.task_runner import SupervisorComms

# Yes, we read and write to stdin! It's a socket, not a normal stdin.
reader, writer = await asyncio.open_connection(sock=socket(fileno=0))
Expand All @@ -948,7 +948,7 @@ async def init_comms(self):
async_reader=reader,
)

task_runner.SUPERVISOR_COMMS = self.comms_decoder
SupervisorComms().set_comms(self.comms_decoder)

msg = await self.comms_decoder._aget_response(expect_id=0)

Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any

from sqlalchemy import Integer, MetaData, String, text
Expand Down Expand Up @@ -96,3 +97,9 @@ class TaskInstanceDependencies(Base):
dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1"))


def is_client_process_context() -> bool:
"""Check if we are in an execution context (Task, Dag Parser or Triggerer perhaps)."""
process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower()
return process_context == "client"
13 changes: 3 additions & 10 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import json
import logging
import re
import sys
import warnings
from contextlib import suppress
from json import JSONDecodeError
Expand All @@ -34,7 +33,7 @@
from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.base import ID_LEN, Base, is_client_process_context
from airflow.models.crypto import get_fernet
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -499,13 +498,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None)
:param team_name: Team name associated to the task trying to access the connection (if any)
:return: connection
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
from airflow.sdk import Connection as TaskSDKConnection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

Expand Down Expand Up @@ -589,7 +582,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
from airflow.sdk import Connection as TaskSDKConnection

warnings.warn(
Expand Down
35 changes: 5 additions & 30 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import contextlib
import json
import logging
import sys
import warnings
from typing import TYPE_CHECKING, Any

Expand All @@ -30,7 +29,7 @@

from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.base import ID_LEN, Base, is_client_process_context
from airflow.models.crypto import get_fernet
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -148,13 +147,7 @@ def get(
:param deserialize_json: Deserialize the value to a Python dict
:param team_name: Team name associated to the task trying to access the variable (if any)
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated."
"Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -208,13 +201,7 @@ def set(
:param team_name: Team name associated to the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.set from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -339,13 +326,7 @@ def update(
:param team_name: Team name associated to the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.update from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
Expand Down Expand Up @@ -405,13 +386,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non
:param team_name: Team name associated to the task trying to delete the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.delete from `airflow.models` is deprecated."
"Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def fn(moment): ...
assert "got an unexpected keyword argument 'not_exists_arg'" in str(err)

@pytest.mark.asyncio
@patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
@patch("airflow.sdk.execution_time.task_runner.SupervisorComms._comms", create=True)
async def test_invalid_trigger(self, supervisor_builder):
"""Test the behaviour when we try to run an invalid Trigger"""
workload = workloads.RunTrigger.model_construct(
Expand Down Expand Up @@ -437,7 +437,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):
await runner.cleanup_finished_triggers()

@pytest.mark.asyncio
@patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
@patch("airflow.sdk.execution_time.task_runner.SupervisorComms._comms", create=True)
async def test_sync_state_to_supervisor(self, supervisor_builder):
trigger_runner = TriggerRunner()
trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder)
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/tests/unit/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
import re
import sys
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -353,6 +354,7 @@ def test_extra_dejson(self):
}

@mock.patch("airflow.sdk.Connection.get")
@mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"})
def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
"""Test the get_connection_from_secrets method with Task SDK success path."""
from airflow.sdk import Connection as SDKConnection
Expand All @@ -361,7 +363,6 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
mock_get.return_value = expected_connection

mock_task_runner = mock.MagicMock()
mock_task_runner.SUPERVISOR_COMMS = True

with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}):
result = Connection.get_connection_from_secrets("test_conn")
Expand All @@ -370,10 +371,10 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
assert result.conn_type == "test_type"

@mock.patch("airflow.sdk.Connection")
@mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"})
def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection):
"""Test the get_connection_from_secrets method with Task SDK not found path."""
mock_task_runner = mock.MagicMock()
mock_task_runner.SUPERVISOR_COMMS = True

mock_task_sdk_connection.get.side_effect = AirflowRuntimeError(
error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
Expand All @@ -383,7 +384,6 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn
with pytest.raises(AirflowNotFoundException):
Connection.get_connection_from_secrets("test_conn")

@mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None})
@mock.patch("airflow.sdk.Connection")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
Expand Down
35 changes: 32 additions & 3 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2277,7 +2277,7 @@ def override_caplog(request):
@pytest.fixture
def mock_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_0_PLUS:
yield None
Expand All @@ -2289,13 +2289,42 @@ def mock_supervisor_comms(monkeypatch):
# core and TaskSDK is finished
if CommsDecoder := getattr(comms, "CommsDecoder", None):
comms = mock.create_autospec(CommsDecoder)
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False)
else:
CommsDecoder = getattr(task_runner, "CommsDecoder")
comms = mock.create_autospec(CommsDecoder)
comms.send = comms.get_message

if AIRFLOW_V_3_2_PLUS:
svcomms = task_runner.SupervisorComms()
old = svcomms.get_comms()
svcomms.set_comms(comms)
yield comms
svcomms.set_comms(old)
else:
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False)
yield comms
yield comms


@pytest.fixture
def mock_unset_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_0_PLUS:
yield None
return

from airflow.sdk.execution_time import comms, task_runner

if AIRFLOW_V_3_2_PLUS:
svcomms = task_runner.SupervisorComms()
old = svcomms.get_comms()
svcomms.reset_comms()
yield comms
svcomms.set_comms(old)
else:
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None, raising=False)
yield comms


@pytest.fixture
Expand Down
Loading