Skip to content

Commit 78b179f

Browse files
authored
Switch to mongo_conn_id argument into the MongoHook constructor (#36896)
1 parent 98fb11b commit 78b179f

4 files changed

Lines changed: 78 additions & 25 deletions

File tree

airflow/providers/mongo/hooks/mongo.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
"""Hook for Mongo DB."""
1919
from __future__ import annotations
2020

21+
import warnings
2122
from ssl import CERT_NONE
2223
from typing import TYPE_CHECKING, Any, overload
2324
from urllib.parse import quote_plus, urlunsplit
2425

2526
import pymongo
2627
from pymongo import MongoClient, ReplaceOne
2728

29+
from airflow.exceptions import AirflowProviderDeprecationWarning
2830
from airflow.hooks.base import BaseHook
2931

3032
if TYPE_CHECKING:
@@ -57,10 +59,19 @@ class MongoHook(BaseHook):
5759
conn_type = "mongo"
5860
hook_name = "MongoDB"
5961

60-
def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
62+
def __init__(self, mongo_conn_id: str = default_conn_name, *args, **kwargs) -> None:
6163
super().__init__(logger_name=kwargs.pop("logger_name", None))
62-
self.mongo_conn_id = conn_id
63-
self.connection = self.get_connection(conn_id)
64+
if conn_id := kwargs.pop("conn_id", None):
65+
warnings.warn(
66+
"Parameter `conn_id` is deprecated and will be removed in a future releases. "
67+
"Please use `mongo_conn_id` instead.",
68+
AirflowProviderDeprecationWarning,
69+
stacklevel=2,
70+
)
71+
mongo_conn_id = conn_id
72+
73+
self.mongo_conn_id = mongo_conn_id
74+
self.connection = self.get_connection(self.mongo_conn_id)
6475
self.extras = self.connection.extra_dejson.copy()
6576
self.client: MongoClient | None = None
6677
self.uri = self._create_uri()

airflow/providers/mongo/sensors/mongo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,5 @@ def poke(self, context: Context) -> bool:
6262
self.log.info(
6363
"Sensor check existence of the document that matches the following query: %s", self.query
6464
)
65-
hook = MongoHook(self.mongo_conn_id)
65+
hook = MongoHook(mongo_conn_id=self.mongo_conn_id)
6666
return hook.find(self.collection, self.query, mongo_db=self.mongo_db, find_one=True) is not None

tests/integration/providers/mongo/sensors/test_mongo.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,32 @@
2323
from airflow.models.dag import DAG
2424
from airflow.providers.mongo.hooks.mongo import MongoHook
2525
from airflow.providers.mongo.sensors.mongo import MongoSensor
26-
from airflow.utils import db, timezone
26+
from airflow.utils import timezone
2727

2828
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
2929

3030

31+
@pytest.fixture(scope="module", autouse=True)
32+
def mongo_connections():
33+
"""Create MongoDB connections which use for testing purpose."""
34+
connections = [
35+
Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
36+
Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test"),
37+
]
38+
39+
with pytest.MonkeyPatch.context() as mp:
40+
for conn in connections:
41+
mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
42+
yield
43+
44+
3145
@pytest.mark.integration("mongo")
3246
class TestMongoSensor:
3347
def setup_method(self):
34-
db.merge_conn(
35-
Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test")
36-
)
37-
3848
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
3949
self.dag = DAG("test_dag_id", default_args=args)
4050

41-
hook = MongoHook("mongo_test")
51+
hook = MongoHook(mongo_conn_id="mongo_test")
4252
hook.insert_one("foo", {"bar": "baz"})
4353

4454
self.sensor = MongoSensor(
@@ -53,7 +63,7 @@ def test_poke(self):
5363
assert self.sensor.poke(None)
5464

5565
def test_sensor_with_db(self):
56-
hook = MongoHook("mongo_test")
66+
hook = MongoHook(mongo_conn_id="mongo_test")
5767
hook.insert_one("nontest", {"1": "2"}, mongo_db="nontest")
5868

5969
sensor = MongoSensor(

tests/providers/mongo/hooks/test_mongo.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
from __future__ import annotations
1919

2020
import importlib
21+
import warnings
2122
from typing import TYPE_CHECKING
2223

2324
import pymongo
2425
import pytest
2526

27+
from airflow.exceptions import AirflowProviderDeprecationWarning
2628
from airflow.models import Connection
2729
from airflow.providers.mongo.hooks.mongo import MongoHook
28-
from airflow.utils import db
2930

3031
pytestmark = pytest.mark.db_test
3132

@@ -40,14 +41,36 @@
4041
mongomock = None
4142

4243

44+
@pytest.fixture(scope="module", autouse=True)
45+
def mongo_connections():
46+
"""Create MongoDB connections which use for testing purpose."""
47+
connections = [
48+
Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
49+
Connection(
50+
conn_id="mongo_default_with_srv",
51+
conn_type="mongo",
52+
host="mongo",
53+
port=27017,
54+
extra='{"srv": true}',
55+
),
56+
# Mongo establishes connection during initialization, so we need to have this connection
57+
Connection(conn_id="fake_connection", conn_type="mongo", host="mongo", port=27017),
58+
]
59+
60+
with pytest.MonkeyPatch.context() as mp:
61+
for conn in connections:
62+
mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
63+
yield
64+
65+
4366
class MongoHookTest(MongoHook):
4467
"""
4568
Extending hook so that a mockmongo collection object can be passed in
4669
to get_collection()
4770
"""
4871

49-
def __init__(self, conn_id="mongo_default", *args, **kwargs):
50-
super().__init__(conn_id=conn_id, *args, **kwargs)
72+
def __init__(self, mongo_conn_id="mongo_default", *args, **kwargs):
73+
super().__init__(mongo_conn_id=mongo_conn_id, *args, **kwargs)
5174

5275
def get_collection(self, mock_collection, mongo_db=None):
5376
return mock_collection
@@ -56,24 +79,33 @@ def get_collection(self, mock_collection, mongo_db=None):
5679
@pytest.mark.skipif(mongomock is None, reason="mongomock package not present")
5780
class TestMongoHook:
5881
def setup_method(self):
59-
self.hook = MongoHookTest(conn_id="mongo_default", mongo_db="default")
82+
self.hook = MongoHookTest(mongo_conn_id="mongo_default")
6083
self.conn = self.hook.get_conn()
61-
db.merge_conn(
62-
Connection(
63-
conn_id="mongo_default_with_srv",
64-
conn_type="mongo",
65-
host="mongo",
66-
port=27017,
67-
extra='{"srv": true}',
84+
85+
def test_mongo_conn_id(self):
86+
with warnings.catch_warnings():
87+
warnings.simplefilter("error", category=AirflowProviderDeprecationWarning)
88+
# Use default "mongo_default"
89+
assert MongoHook().mongo_conn_id == "mongo_default"
90+
# Positional argument
91+
assert MongoHook("fake_connection").mongo_conn_id == "fake_connection"
92+
93+
warning_message = "Parameter `conn_id` is deprecated"
94+
with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
95+
assert MongoHook(conn_id="fake_connection").mongo_conn_id == "fake_connection"
96+
97+
with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
98+
assert (
99+
MongoHook(conn_id="fake_connection", mongo_conn_id="foo-bar").mongo_conn_id
100+
== "fake_connection"
68101
)
69-
)
70102

71103
def test_get_conn(self):
72104
assert self.hook.connection.port == 27017
73105
assert isinstance(self.conn, pymongo.MongoClient)
74106

75107
def test_srv(self):
76-
hook = MongoHook(conn_id="mongo_default_with_srv")
108+
hook = MongoHook(mongo_conn_id="mongo_default_with_srv")
77109
assert hook.uri.startswith("mongodb+srv://")
78110

79111
def test_insert_one(self):
@@ -333,7 +365,7 @@ def test_distinct_with_filter(self):
333365

334366

335367
def test_context_manager():
336-
with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook:
368+
with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:
337369
ctx_hook.get_conn()
338370

339371
assert isinstance(ctx_hook, MongoHook)

0 commit comments

Comments
 (0)