Skip to content

Commit f33b686

Browse files
authored
Simplify checks for package versions (#37585)
Replaces more complex package version checks with one-liners.
1 parent 492d90c commit f33b686

5 files changed

Lines changed: 19 additions & 44 deletions

File tree

airflow/utils/pydantic.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,14 @@
2424

2525
from __future__ import annotations
2626

27+
from importlib import metadata
2728

28-
def is_pydantic_2_installed() -> bool:
29-
import sys
29+
from packaging import version
3030

31-
from packaging.version import Version
3231

33-
if sys.version_info >= (3, 9):
34-
from importlib.metadata import distribution
35-
else:
36-
from importlib_metadata import distribution
32+
def is_pydantic_2_installed() -> bool:
3733
try:
38-
return Version(distribution("pydantic").version) >= Version("2.0.0")
34+
return version.parse(metadata.version("pydantic")).major == 2
3935
except ImportError:
4036
return False
4137

airflow/utils/sqlalchemy.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import datetime
2323
import json
2424
import logging
25-
from importlib.metadata import version
25+
from importlib import metadata
2626
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
2727

2828
from dateutil import relativedelta
29-
from packaging.version import Version, parse as parse_version
29+
from packaging import version
3030
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
3131
from sqlalchemy.dialects import mysql
3232
from sqlalchemy.types import JSON, Text, TypeDecorator
@@ -555,10 +555,5 @@ def get_orm_mapper():
555555
return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper
556556

557557

558-
def _get_lib_major_version(lib_name: str) -> int:
559-
ver: Version = parse_version(version(lib_name))
560-
return ver.major
561-
562-
563558
def is_sqlalchemy_v1() -> bool:
564-
return _get_lib_major_version("sqlalchemy") == 1
559+
return version.parse(metadata.version("sqlalchemy")).major == 1

airflow/utils/timezone.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@
1818
from __future__ import annotations
1919

2020
import datetime as dt
21+
from importlib import metadata
2122
from typing import TYPE_CHECKING, overload
2223

2324
import pendulum
2425
from dateutil.relativedelta import relativedelta
26+
from packaging import version
2527
from pendulum.datetime import DateTime
2628

2729
if TYPE_CHECKING:
2830
from pendulum.tz.timezone import FixedTimezone, Timezone
2931

3032
from airflow.typing_compat import Literal
3133

32-
_PENDULUM3 = pendulum.__version__.startswith("3")
34+
_PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
3335
# UTC Timezone as a tzinfo instance. Actual value depends on pendulum version:
3436
# - Timezone("UTC") in pendulum 3
3537
# - FixedTimezone(0, "UTC") in pendulum 2

tests/serialization/serializers/test_serializers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import datetime
2020
import decimal
21+
from importlib import metadata
2122
from unittest.mock import patch
2223

2324
import numpy as np
@@ -26,6 +27,7 @@
2627
import pytest
2728
from dateutil.tz import tzutc
2829
from deltalake import DeltaTable
30+
from packaging import version
2931
from pendulum import DateTime
3032
from pendulum.tz.timezone import FixedTimezone, Timezone
3133

@@ -38,7 +40,7 @@
3840
else:
3941
from backports.zoneinfo import ZoneInfo
4042

41-
PENDULUM3 = pendulum.__version__.startswith("3")
43+
PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
4244

4345

4446
class TestSerializers:

tests/utils/test_sqlalchemy.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from airflow.settings import Session
3636
from airflow.utils.sqlalchemy import (
3737
ExecutorConfigType,
38-
_get_lib_major_version,
3938
ensure_pod_is_valid_after_unpickling,
4039
is_sqlalchemy_v1,
4140
prohibit_commit,
@@ -317,32 +316,13 @@ def test_result_processor_bad_pickled_obj(self):
317316

318317

319318
@pytest.mark.parametrize(
320-
"version_string, expected_major_version",
319+
"mock_version, expected_result",
321320
[
322-
("1.4.22", 1), # Test 1: "1.4.22" parsed as 1
323-
("10.4.22", 10), # Test 2: "10.4.22" not parsed as 1
324-
("invalid", None), # Test 3: Invalid version string
325-
("3.x.x", None), # Test 4: Malformed version
321+
("1.0.0", True), # Test 1: v1 identified as v1
322+
("2.3.4", False), # Test 2: v2 not identified as v1
326323
],
327324
)
328-
def test_get_lib_major_version(version_string, expected_major_version):
329-
with mock.patch("airflow.utils.sqlalchemy.version") as mock_version:
330-
mock_version.return_value = version_string
331-
if expected_major_version is not None:
332-
assert _get_lib_major_version("dummy_module") == expected_major_version
333-
else:
334-
with pytest.raises(ValueError):
335-
_get_lib_major_version("dummy_module")
336-
337-
338-
@pytest.mark.parametrize(
339-
"major_version, expected_result",
340-
[
341-
(1, True), # Test 1: v1 identified as v1
342-
(2, False), # Test 2: v2 not identified as v1
343-
],
344-
)
345-
def test_is_sqlalchemy_v1(major_version, expected_result):
346-
with mock.patch("airflow.utils.sqlalchemy._get_lib_major_version") as mock_get_major_version:
347-
mock_get_major_version.return_value = major_version
325+
def test_is_sqlalchemy_v1(mock_version, expected_result):
326+
with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata:
327+
mock_metadata.version.return_value = mock_version
348328
assert is_sqlalchemy_v1() == expected_result

0 commit comments

Comments
 (0)