Skip to content
20 changes: 15 additions & 5 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections.abc import Container
from contextlib import AbstractContextManager
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from kubernetes.client import CoreV1Api, models as k8s
from slugify import slugify
Expand Down Expand Up @@ -59,6 +59,7 @@
get_container_termination_message,
)
from airflow.settings import pod_mutation_hook
from airflow.typing_compat import Literal
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
from airflow.utils.timezone import utcnow
Expand Down Expand Up @@ -175,6 +176,10 @@ class KubernetesPodOperator(BaseOperator):
:param labels: labels to apply to the Pod. (templated)
:param startup_timeout_seconds: timeout in seconds to startup the pod.
:param get_logs: get the stdout of the base container as logs of the tasks.
:param container_logs: containers whose logs will be published to stdout.
Takes a sequence of container names, a single container name, or True. If True,
the logs of all containers are published. Works in conjunction with the get_logs param.
The default value is the base container.
:param image_pull_policy: Specify a policy to cache or always pull an image.
:param annotations: non-identifying metadata you can attach to the Pod.
Can be a large range of data, and can include characters
Expand Down Expand Up @@ -271,6 +276,7 @@ def __init__(
reattach_on_restart: bool = True,
startup_timeout_seconds: int = 120,
get_logs: bool = True,
container_logs: Iterable[str] | str | Literal[True] = BASE_CONTAINER_NAME,
image_pull_policy: str | None = None,
annotations: dict | None = None,
container_resources: k8s.V1ResourceRequirements | None = None,
Expand Down Expand Up @@ -342,6 +348,11 @@ def __init__(
self.cluster_context = cluster_context
self.reattach_on_restart = reattach_on_restart
self.get_logs = get_logs
self.container_logs = container_logs
if self.container_logs == KubernetesPodOperator.BASE_CONTAINER_NAME:
self.container_logs = (
base_container_name if base_container_name else KubernetesPodOperator.BASE_CONTAINER_NAME
)
self.image_pull_policy = image_pull_policy
self.node_selector = node_selector or {}
self.annotations = annotations or {}
Expand Down Expand Up @@ -551,11 +562,10 @@ def execute_sync(self, context: Context):
self.await_pod_start(pod=self.pod)

if self.get_logs:
self.pod_manager.fetch_container_logs(
self.pod_manager.fetch_requested_container_logs(
pod=self.pod,
container_name=self.base_container_name,
follow=True,
post_termination_timeout=self.POST_TERMINATION_TIMEOUT,
container_logs=self.container_logs,
follow_logs=True,
)
else:
self.pod_manager.await_container_completion(
Expand Down
95 changes: 91 additions & 4 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import math
import time
import warnings
from collections.abc import Iterable
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -42,7 +43,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.kubernetes.pod_generator import PodDefaults
from airflow.typing_compat import Protocol
from airflow.typing_compat import Literal, Protocol
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timezone import utcnow

Expand Down Expand Up @@ -122,6 +123,17 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:
return container_status.state.running is not None


def container_is_completed(pod: V1Pod, container_name: str) -> bool:
    """
    Tell whether ``container_name`` inside V1Pod ``pod`` has completed.

    The result is True only when a status for the container is found and its
    ``state.terminated`` field is set; in every other case it is False.
    """
    status = get_container_status(pod, container_name)
    # A missing status short-circuits to False before ``state`` is touched.
    return bool(status) and status.state.terminated is not None


def container_is_terminated(pod: V1Pod, container_name: str) -> bool:
"""
Examines V1Pod ``pod`` to determine whether ``container_name`` is terminated.
Expand Down Expand Up @@ -378,11 +390,12 @@ def consume_logs(
for raw_line in logs:
line = raw_line.decode("utf-8", errors="backslashreplace")
timestamp, message = self.parse_log_line(line)
self.log.info(message)
self.log.info("[%s] %s", container_name, message)
except BaseHTTPError as e:
self.log.warning(
"Reading of logs interrupted with error %r; will retry. "
"Reading of logs interrupted for container %r with error %r; will retry. "
"Set log level to DEBUG for traceback.",
container_name,
e,
)
self.log.debug(
Expand Down Expand Up @@ -412,14 +425,78 @@ def consume_logs(
)
time.sleep(1)

def fetch_requested_container_logs(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this entire function can be transformed into a generator function instead and use yield to get rid of all the append calls.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo, it is simpler to control the flow when we have these append calls. I would like to keep it this way if that is ok by you

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we need to use the provided post_termination_timeout in this method as we did with fetch_container_logs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I quite understand this comment. We are calling the fetch_container_logs from fetch_requested_container_logs which will take care of post_termination_timeout right?

self, pod: V1Pod, container_logs: Iterable[str] | str | Literal[True], follow_logs=False
) -> list[PodLoggingStatus]:
"""
Fetch the logs of the requested containers of the pod, one container after another.

Each selected container is handed to ``fetch_container_logs`` and the status it
returns is collected. Requested containers that are not part of the pod, and
unsupported ``container_logs`` values, are reported through ``self.log.error``
and skipped.

:param pod: pod whose container logs are requested
:param container_logs: a single container name, an iterable of container names,
    or True to select every container returned by ``get_container_names``
:param follow_logs: forwarded to ``fetch_container_logs`` as ``follow``
:return: the statuses returned by ``fetch_container_logs``, one per container
    whose logs were fetched; empty when nothing was fetched
"""
pod_logging_statuses = []
# The pod's container names come from get_container_names(); an empty result means nothing to fetch.
all_containers = self.get_container_names(pod)
if len(all_containers) == 0:
self.log.error("Could not retrieve containers for the pod: %s", pod.metadata.name)
else:
# str is tested first because a str is itself an Iterable and would otherwise
# be walked character by character in the generic branch below.
if isinstance(container_logs, str):
# a single container name: fetch its logs only if that container exists in the pod
# NOTE(review): only pod, container_name and follow are forwarded to
# fetch_container_logs from this method; no post_termination_timeout is passed.
if container_logs in all_containers:
status = self.fetch_container_logs(
pod=pod, container_name=container_logs, follow=follow_logs
)
pod_logging_statuses.append(status)
else:
self.log.error(
"container %s whose logs were requested not found in the pod %s",
container_logs,
pod.metadata.name,
)
elif isinstance(container_logs, bool):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to keep this check as bool so that we can reject/filter out invalid/unsupported types too here https://github.com/apache/airflow/pull/31663/files#diff-6900da9281d8404b14da0815f2b37350f3148b1b63449928b952744e6711e7e7R475-R478
@uranusjr

# True selects every container: fetch logs for each name in all_containers
if container_logs is True:
for container_name in all_containers:
status = self.fetch_container_logs(
pod=pod, container_name=container_name, follow=follow_logs
)
pod_logging_statuses.append(status)
else:
# False is the only other bool; it is reported as invalid and nothing is fetched
self.log.error(
"False is not a valid value for container_logs",
)
Comment on lines 461 to 464
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this branch needed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we need it since we support only Literal[True] when it comes to boolean. Refer to my unit tests. It adds False too as an invalid test case

else:
# any other iterable: fetch logs for each listed name that exists in the pod
if isinstance(container_logs, Iterable):
for container in container_logs:
if container in all_containers:
status = self.fetch_container_logs(
pod=pod, container_name=container, follow=follow_logs
)
pod_logging_statuses.append(status)
else:
# NOTE(review): "requests" in the message below looks like a typo
# for "requested" (compare the single-name branch) - confirm.
self.log.error(
"Container %s whose logs were requests not found in the pod %s",
container,
pod.metadata.name,
)
else:
# neither str, bool nor Iterable: report the offending type and fetch nothing
self.log.error(
"Invalid type %s specified for container names input parameter", type(container_logs)
)
Comment thread
potiuk marked this conversation as resolved.
Outdated

return pod_logging_statuses

def await_container_completion(self, pod: V1Pod, container_name: str) -> None:
"""
Block until the given container in the given pod is completed.

The container state is checked in a loop, with a one second sleep between checks.

:param pod: pod spec that will be monitored
:param container_name: name of the container within the pod to monitor
"""
# NOTE(review): this line and the ``while True`` loop below look like the two sides
# of one diff hunk rather than nested loops - confirm against the real file.
while not self.container_is_terminated(pod=pod, container_name=container_name):
while True:
# the pod is re-read on every pass; presumably this returns the current remote state
remote_pod = self.read_pod(pod)
terminated = container_is_completed(remote_pod, container_name)
if terminated:
break
self.log.info("Waiting for container '%s' state to be completed", container_name)
time.sleep(1)

def await_pod_completion(self, pod: V1Pod) -> V1Pod:
Expand Down Expand Up @@ -512,6 +589,16 @@ def read_pod_logs(
post_termination_timeout=post_termination_timeout,
)

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def get_container_names(self, pod: V1Pod) -> list[str]:
    """
    List the names of the containers declared in the pod spec.

    The pod is read through ``read_pod`` and the airflow-xcom-sidecar container
    (``PodDefaults.SIDECAR_CONTAINER_NAME``) is left out of the result.
    """
    names = []
    for spec in self.read_pod(pod).spec.containers:
        # The xcom sidecar is never reported as a container of the pod.
        if spec.name == PodDefaults.SIDECAR_CONTAINER_NAME:
            continue
        names.append(spec.name)
    return names

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
"""Reads events from the POD."""
Expand Down
2 changes: 1 addition & 1 deletion kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def test_volume_mount(self, mock_get_connection):
)
context = create_context(k)
k.execute(context=context)
mock_logger.info.assert_any_call("retrieved from mount")
mock_logger.info.assert_any_call("[%s] %s", "base", "retrieved from mount")
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod["spec"]["containers"][0]["args"] = args
self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/cncf/kubernetes/utils/test_pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,46 @@ def test_fetch_container_done(self, logs_available, container_running, follow):
assert ret.last_log_time is None
assert ret.running is False

# covers every supported form of container_logs: a single name, True and a list of names
@pytest.mark.parametrize("follow", [True, False])
@pytest.mark.parametrize("container_logs", ["base", "alpine", True, ["base", "alpine"]])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs(self, container_is_running, container_logs, follow):
    """Every supported ``container_logs`` value yields one non-running status per selected container."""
    mock_pod = MagicMock()
    self.pod_manager.read_pod = MagicMock()
    self.pod_manager.get_container_names = MagicMock()
    self.pod_manager.get_container_names.return_value = ["base", "alpine"]
    container_is_running.return_value = False
    self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
        stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
    )

    ret_values = self.pod_manager.fetch_requested_container_logs(
        pod=mock_pod, container_logs=container_logs, follow_logs=follow
    )

    # Without a count check the loop below passes vacuously when nothing is returned.
    # A single name selects one container; True and the two-name list select both.
    expected_count = 1 if isinstance(container_logs, str) else 2
    assert len(ret_values) == expected_count
    for ret in ret_values:
        assert ret.running is False

# unsupported container_logs values: non-iterable types and False
@pytest.mark.parametrize("container_logs", [1, None, 6.8, False])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs_invalid(self, container_running, container_logs):
    """An unsupported ``container_logs`` value must produce no logging statuses."""
    pod = MagicMock()
    self.pod_manager.read_pod = MagicMock()
    self.pod_manager.get_container_names = MagicMock(return_value=["base", "alpine"])
    container_running.return_value = False
    log_stream = mock.MagicMock(return_value=[b"2021-01-01 hi"])
    self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(stream=log_stream)

    statuses = self.pod_manager.fetch_requested_container_logs(pod=pod, container_logs=container_logs)

    assert len(statuses) == 0

@mock.patch("pendulum.now")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodLogsConsumer.logs_available")
Expand Down