Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions .github/workflows/check-newsfragment-pr-number.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
---
name: Check newsfragment PR number
on:  # yamllint disable-line rule:truthy
  pull_request:
    branches:
      - main
      - v[0-9]+-[0-9]+-test
      - v[0-9]+-[0-9]+-stable
      - providers-[a-z]+-?[a-z]*/v[0-9]+-[0-9]+
    types: [opened, reopened, synchronize]

permissions:
  contents: read
  pull-requests: read

concurrency:
  group: check-newsfragment-${{ github.event.pull_request.number }}
  cancel-in-progress: true

jobs:
  check-newsfragment-pr-number:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - name: Check newsfragment PR number
        env:
          GH_TOKEN: ${{ github.token }}
          PR_NUMBER: "${{ github.event.pull_request.number }}"
        run: |
          # Find newsfragment files whose PR number doesn't match this PR.
          # Use the REST API to get file statuses so we can exclude deleted files.
          # GITHUB_REPOSITORY (a default Actions env var) is used instead of an
          # inline ${{ }} expression so no workflow-expression text is ever
          # expanded directly into the shell script.
          # The trailing "|| true" keeps the pipeline from failing the step under
          # the default "bash -e -o pipefail" shell when no file matches.
          bad=$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${PR_NUMBER}/files" \
            --paginate --jq '.[] | select(.status != "removed") | .filename' \
            | grep '/newsfragments/.*\.rst$' \
            | grep -v "/newsfragments/${PR_NUMBER}\." || true)

          if [ -n "$bad" ]; then
            echo "::error::Newsfragment PR number mismatch. Expected ${PR_NUMBER} but found: ${bad}"
            exit 1
          fi
3 changes: 2 additions & 1 deletion task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
kwargs.setdefault("base_url", "dry-run://server")
else:
kwargs["base_url"] = base_url
kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)
# Call via the class to avoid binding lru_cache wires to this instance.
kwargs["verify"] = type(self)._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)

# Set timeout if not explicitly provided
kwargs.setdefault("timeout", API_TIMEOUT)
Expand Down
8 changes: 6 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import collections
import contextlib
import inspect
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
Expand Down Expand Up @@ -197,8 +198,11 @@ async def _async_get_connection(conn_id: str) -> Connection:
for secrets_backend in backends:
try:
# Use async method if available, otherwise wrap sync method
if hasattr(secrets_backend, "aget_connection"):
conn = await secrets_backend.aget_connection(conn_id) # type: ignore[assignment]
# getattr avoids triggering AsyncMock coroutine creation under Python 3.13
async_method = getattr(secrets_backend, "aget_connection", None)
if async_method is not None:
maybe_awaitable = async_method(conn_id)
conn = await maybe_awaitable if inspect.isawaitable(maybe_awaitable) else maybe_awaitable
else:
conn = await sync_to_async(secrets_backend.get_connection)(conn_id) # type: ignore[assignment]

Expand Down
37 changes: 21 additions & 16 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@
from datetime import datetime, timezone
from http import HTTPStatus
from socket import socket, socketpair
from typing import (
TYPE_CHECKING,
BinaryIO,
ClassVar,
NoReturn,
TextIO,
cast,
)
from typing import TYPE_CHECKING, BinaryIO, ClassVar, NoReturn, TextIO, cast
from urllib.parse import urlparse
from uuid import UUID

Expand Down Expand Up @@ -898,17 +891,29 @@ def _remote_logging_conn(client: Client):
# Fetch connection details on-demand without caching the entire API client instance
conn = _fetch_remote_logging_conn(conn_id, client)

if conn:
key = f"AIRFLOW_CONN_{conn_id.upper()}"
old = os.getenv(key)
os.environ[key] = conn.get_uri()
if not conn:
try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old
# Ensure we don't leak the caller's client when no connection was fetched.
del conn
del client
return

key = f"AIRFLOW_CONN_{conn_id.upper()}"
old = os.getenv(key)
os.environ[key] = conn.get_uri()
try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old

# Explicitly drop local references so the caller's client can be garbage collected.
del conn
del client


@attrs.define(kw_only=True)
Expand Down
22 changes: 21 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,29 @@ def test_runtime_error_triggers_greenback_fallback(self, mocker, mock_supervisor

# Mock the greenback and asyncio modules that are imported inside the exception handler
mocker.patch("greenback.has_portal", return_value=True)
mock_greenback_await = mocker.patch("greenback.await_", return_value=expected_conn)
mocker.patch("asyncio.current_task")

# Mock greenback.await_ to actually await the coroutine it receives.
# This prevents Python 3.13 RuntimeWarning about unawaited coroutines.
import asyncio

def greenback_await_side_effect(coro):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()

mock_greenback_await = mocker.patch("greenback.await_", side_effect=greenback_await_side_effect)

# Mock aget_connection to return the expected connection directly.
# We need to mock this because the real aget_connection would try to
# use SUPERVISOR_COMMS.asend which is not set up for this test.
async def mock_aget_connection(self, conn_id):
return expected_conn

mocker.patch.object(ExecutionAPISecretsBackend, "aget_connection", mock_aget_connection)

backend = ExecutionAPISecretsBackend()
conn = backend.get_connection("databricks_default")

Expand Down
31 changes: 29 additions & 2 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,35 @@ class TestWatchedSubprocess:
def disable_log_upload(self, spy_agency):
spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False)

# TODO: Investigate and fix it after 3.1.0
@pytest.mark.xfail(reason="Fails on Py 3.12 with multi-threading error only in tests.")
@pytest.fixture(autouse=True)
def use_real_secrets_backends(self, monkeypatch):
"""
Ensure that real secrets backend instances are used instead of mocks.

This prevents Python 3.13 RuntimeWarning when hasattr checks async methods
on mocked backends. The warning occurs because hasattr on AsyncMock creates
unawaited coroutines.

This fixture ensures test isolation when running in parallel with pytest-xdist,
regardless of what other tests patch.
"""
import importlib

import airflow.sdk.execution_time.secrets.execution_api as execution_api_module
from airflow.secrets.environment_variables import EnvironmentVariablesBackend

fresh_execution_backend = importlib.reload(execution_api_module).ExecutionAPISecretsBackend

# Ensure downstream imports see the restored class instead of any AsyncMock left by other tests
import airflow.sdk.execution_time.secrets as secrets_package

monkeypatch.setattr(secrets_package, "ExecutionAPISecretsBackend", fresh_execution_backend)

monkeypatch.setattr(
"airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded",
lambda: [EnvironmentVariablesBackend(), fresh_execution_backend()],
)

def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start):
def subprocess_main():
# This is run in the subprocess!
Expand Down