1818from __future__ import annotations
1919
2020import importlib
21+ import warnings
2122from typing import TYPE_CHECKING
2223
2324import pymongo
2425import pytest
2526
27+ from airflow .exceptions import AirflowProviderDeprecationWarning
2628from airflow .models import Connection
2729from airflow .providers .mongo .hooks .mongo import MongoHook
28- from airflow .utils import db
2930
3031pytestmark = pytest .mark .db_test
3132
4041 mongomock = None
4142
4243
44+ @pytest .fixture (scope = "module" , autouse = True )
45+ def mongo_connections ():
46+ """Create MongoDB connections which use for testing purpose."""
47+ connections = [
48+ Connection (conn_id = "mongo_default" , conn_type = "mongo" , host = "mongo" , port = 27017 ),
49+ Connection (
50+ conn_id = "mongo_default_with_srv" ,
51+ conn_type = "mongo" ,
52+ host = "mongo" ,
53+ port = 27017 ,
54+ extra = '{"srv": true}' ,
55+ ),
56+ # Mongo establishes connection during initialization, so we need to have this connection
57+ Connection (conn_id = "fake_connection" , conn_type = "mongo" , host = "mongo" , port = 27017 ),
58+ ]
59+
60+ with pytest .MonkeyPatch .context () as mp :
61+ for conn in connections :
62+ mp .setenv (f"AIRFLOW_CONN_{ conn .conn_id .upper ()} " , conn .as_json ())
63+ yield
64+
65+
4366class MongoHookTest (MongoHook ):
4467 """
4568 Extending hook so that a mockmongo collection object can be passed in
4669 to get_collection()
4770 """
4871
49- def __init__ (self , conn_id = "mongo_default" , * args , ** kwargs ):
50- super ().__init__ (conn_id = conn_id , * args , ** kwargs )
72+ def __init__ (self , mongo_conn_id = "mongo_default" , * args , ** kwargs ):
73+ super ().__init__ (mongo_conn_id = mongo_conn_id , * args , ** kwargs )
5174
5275 def get_collection (self , mock_collection , mongo_db = None ):
5376 return mock_collection
@@ -56,24 +79,33 @@ def get_collection(self, mock_collection, mongo_db=None):
5679@pytest .mark .skipif (mongomock is None , reason = "mongomock package not present" )
5780class TestMongoHook :
5881 def setup_method (self ):
59- self .hook = MongoHookTest (conn_id = "mongo_default" , mongo_db = "default " )
82+ self .hook = MongoHookTest (mongo_conn_id = "mongo_default" )
6083 self .conn = self .hook .get_conn ()
61- db .merge_conn (
62- Connection (
63- conn_id = "mongo_default_with_srv" ,
64- conn_type = "mongo" ,
65- host = "mongo" ,
66- port = 27017 ,
67- extra = '{"srv": true}' ,
84+
85+ def test_mongo_conn_id (self ):
86+ with warnings .catch_warnings ():
87+ warnings .simplefilter ("error" , category = AirflowProviderDeprecationWarning )
88+ # Use default "mongo_default"
89+ assert MongoHook ().mongo_conn_id == "mongo_default"
90+ # Positional argument
91+ assert MongoHook ("fake_connection" ).mongo_conn_id == "fake_connection"
92+
93+ warning_message = "Parameter `conn_id` is deprecated"
94+ with pytest .warns (AirflowProviderDeprecationWarning , match = warning_message ):
95+ assert MongoHook (conn_id = "fake_connection" ).mongo_conn_id == "fake_connection"
96+
97+ with pytest .warns (AirflowProviderDeprecationWarning , match = warning_message ):
98+ assert (
99+ MongoHook (conn_id = "fake_connection" , mongo_conn_id = "foo-bar" ).mongo_conn_id
100+ == "fake_connection"
68101 )
69- )
70102
71103 def test_get_conn (self ):
72104 assert self .hook .connection .port == 27017
73105 assert isinstance (self .conn , pymongo .MongoClient )
74106
75107 def test_srv (self ):
76- hook = MongoHook (conn_id = "mongo_default_with_srv" )
108+ hook = MongoHook (mongo_conn_id = "mongo_default_with_srv" )
77109 assert hook .uri .startswith ("mongodb+srv://" )
78110
79111 def test_insert_one (self ):
@@ -333,7 +365,7 @@ def test_distinct_with_filter(self):
333365
334366
335367def test_context_manager ():
336- with MongoHook (conn_id = "mongo_default" , mongo_db = "default " ) as ctx_hook :
368+ with MongoHook (mongo_conn_id = "mongo_default" ) as ctx_hook :
337369 ctx_hook .get_conn ()
338370
339371 assert isinstance (ctx_hook , MongoHook )
0 commit comments