Skip to content

Commit 4a7261e

Browse files
committed
Add rudimentary support for psycopg3
Note: the new code will become active only if SQLAlchemy 2 and psycopg3 are both detected. Also: refactor the Postgres tests to (1) use the `mocker` fixture, and (2) stop using the `caplog` fixture.
1 parent 9b46d76 commit 4a7261e

16 files changed

Lines changed: 1080 additions & 337 deletions

File tree

airflow-core/src/airflow/settings.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,16 @@
4343
from airflow.utils.orm_event_handlers import setup_event_handlers
4444
from airflow.utils.sqlalchemy import is_sqlalchemy_v1
4545

46+
USE_PSYCOPG3: bool
47+
try:
48+
from importlib.util import find_spec
49+
50+
is_psycopg3 = find_spec("psycopg") is not None
51+
52+
USE_PSYCOPG3 = is_psycopg3 and not is_sqlalchemy_v1()
53+
except (ImportError, ModuleNotFoundError):
54+
USE_PSYCOPG3 = False
55+
4656
if TYPE_CHECKING:
4757
from sqlalchemy.engine import Engine
4858

@@ -426,12 +436,17 @@ def clean_in_fork():
426436
register_at_fork(after_in_child=clean_in_fork)
427437

428438

429-
DEFAULT_ENGINE_ARGS = {
430-
"postgresql": {
431-
"executemany_mode": "values_plus_batch",
432-
"executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000,
433-
"executemany_batch_page_size": 2000,
434-
},
439+
DEFAULT_ENGINE_ARGS: dict[str, dict[str, Any]] = {
440+
"postgresql": (
441+
{
442+
"executemany_values_page_size" if is_sqlalchemy_v1() else "insertmanyvalues_page_size": 10000,
443+
}
444+
| (
445+
{}
446+
if USE_PSYCOPG3
447+
else {"executemany_mode": "values_plus_batch", "executemany_batch_page_size": 2000}
448+
)
449+
)
435450
}
436451

437452

airflow-core/src/airflow/utils/db.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,21 @@
6161
from airflow.utils.session import NEW_SESSION, provide_session
6262
from airflow.utils.task_instance_session import get_current_task_instance_session
6363

64+
USE_PSYCOPG3: bool
65+
try:
66+
from importlib.util import find_spec
67+
68+
import sqlalchemy
69+
from packaging.version import Version
70+
71+
is_psycopg3 = find_spec("psycopg") is not None
72+
sqlalchemy_version = Version(sqlalchemy.__version__)
73+
is_sqla2 = (sqlalchemy_version.major, sqlalchemy_version.minor, sqlalchemy_version.micro) >= (2, 0, 0)
74+
75+
USE_PSYCOPG3 = is_psycopg3 and is_sqla2
76+
except (ImportError, ModuleNotFoundError):
77+
USE_PSYCOPG3 = False
78+
6479
if TYPE_CHECKING:
6580
from alembic.runtime.environment import EnvironmentContext
6681
from alembic.script import ScriptDirectory
@@ -1284,15 +1299,28 @@ def create_global_lock(
12841299
dialect = conn.dialect
12851300
try:
12861301
if dialect.name == "postgresql":
1287-
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
1288-
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
1302+
if USE_PSYCOPG3:
1303+
# psycopg3 doesn't support parameters for `SET`. Use `set_config` instead.
1304+
# The timeout value must be passed as a string of milliseconds.
1305+
conn.execute(
1306+
text("SELECT set_config('lock_timeout', :timeout, false)"),
1307+
{"timeout": str(lock_timeout)},
1308+
)
1309+
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
1310+
else:
1311+
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
1312+
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
12891313
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
12901314
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
12911315

12921316
yield
12931317
finally:
12941318
if dialect.name == "postgresql":
1295-
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
1319+
if USE_PSYCOPG3:
1320+
# Use set_config() to reset the timeout to its default (0 = off/wait forever).
1321+
conn.execute(text("SELECT set_config('lock_timeout', '0', false)"))
1322+
else:
1323+
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
12961324
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
12971325
if not unlocked:
12981326
raise RuntimeError("Error releasing DB lock!")

airflow-core/tests/unit/always/test_connection.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
from airflow.models import Connection, crypto
3333
from airflow.sdk import BaseHook
3434

35-
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4
35+
from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4, SQLALCHEMY_V_2_0
3636

3737
sqlite = pytest.importorskip("airflow.providers.sqlite.hooks.sqlite")
3838

@@ -679,10 +679,20 @@ def test_env_var_priority(self, mock_supervisor_comms):
679679
def test_dbapi_get_uri(self):
680680
conn = BaseHook.get_connection(conn_id="test_uri")
681681
hook = conn.get_hook()
682-
assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database"
682+
683+
ppg3_mode: bool = SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri()
684+
if ppg3_mode:
685+
assert (
686+
hook.get_uri() == "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database"
687+
)
688+
else:
689+
assert hook.get_uri() == "postgresql://username:password@ec2.compute.com:5432/the_database"
683690
conn2 = BaseHook.get_connection(conn_id="test_uri_no_creds")
684691
hook2 = conn2.get_hook()
685-
assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database"
692+
if ppg3_mode:
693+
assert hook2.get_uri() == "postgresql+psycopg://ec2.compute.com/the_database"
694+
else:
695+
assert hook2.get_uri() == "postgresql://ec2.compute.com/the_database"
686696

687697
@mock.patch.dict(
688698
"os.environ",
@@ -695,7 +705,12 @@ def test_dbapi_get_sqlalchemy_engine(self):
695705
conn = BaseHook.get_connection(conn_id="test_uri")
696706
hook = conn.get_hook()
697707
engine = hook.get_sqlalchemy_engine()
698-
expected = "postgresql://username:password@ec2.compute.com:5432/the_database"
708+
709+
if SQLALCHEMY_V_2_0 and "psycopg" in hook.get_uri():
710+
expected = "postgresql+psycopg://username:password@ec2.compute.com:5432/the_database"
711+
else:
712+
expected = "postgresql://username:password@ec2.compute.com:5432/the_database"
713+
699714
assert isinstance(engine, sqlalchemy.engine.Engine)
700715
if SQLALCHEMY_V_1_4:
701716
assert str(engine.url) == expected

airflow-core/tests/unit/cli/commands/test_db_command.py

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,27 @@ def test_cli_shell_postgres(self, mock_execute_interactive):
238238
"PGUSER": "postgres",
239239
}
240240

241+
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
242+
@mock.patch(
243+
"airflow.cli.commands.db_command.settings.engine.url",
244+
make_url("postgresql+psycopg://postgres:airflow@postgres:5432/airflow"),
245+
)
246+
def test_cli_shell_postgres_ppg3(self, mock_execute_interactive):
247+
pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")
248+
249+
db_command.shell(self.parser.parse_args(["db", "shell"]))
250+
mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY)
251+
_, kwargs = mock_execute_interactive.call_args
252+
env = kwargs["env"]
253+
postgres_env = {k: v for k, v in env.items() if k.startswith("PG")}
254+
assert postgres_env == {
255+
"PGDATABASE": "airflow",
256+
"PGHOST": "postgres",
257+
"PGPASSWORD": "airflow",
258+
"PGPORT": "5432",
259+
"PGUSER": "postgres",
260+
}
261+
241262
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
242263
@mock.patch(
243264
"airflow.cli.commands.db_command.settings.engine.url",
@@ -257,6 +278,27 @@ def test_cli_shell_postgres_without_port(self, mock_execute_interactive):
257278
"PGUSER": "postgres",
258279
}
259280

281+
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
282+
@mock.patch(
283+
"airflow.cli.commands.db_command.settings.engine.url",
284+
make_url("postgresql+psycopg://postgres:airflow@postgres/airflow"),
285+
)
286+
def test_cli_shell_postgres_without_port_ppg3(self, mock_execute_interactive):
287+
pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")
288+
289+
db_command.shell(self.parser.parse_args(["db", "shell"]))
290+
mock_execute_interactive.assert_called_once_with(["psql"], env=mock.ANY)
291+
_, kwargs = mock_execute_interactive.call_args
292+
env = kwargs["env"]
293+
postgres_env = {k: v for k, v in env.items() if k.startswith("PG")}
294+
assert postgres_env == {
295+
"PGDATABASE": "airflow",
296+
"PGHOST": "postgres",
297+
"PGPASSWORD": "airflow",
298+
"PGPORT": "5432",
299+
"PGUSER": "postgres",
300+
}
301+
260302
@mock.patch(
261303
"airflow.cli.commands.db_command.settings.engine.url",
262304
make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"),
@@ -265,6 +307,16 @@ def test_cli_shell_invalid(self):
265307
with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"):
266308
db_command.shell(self.parser.parse_args(["db", "shell"]))
267309

310+
@mock.patch(
311+
"airflow.cli.commands.db_command.settings.engine.url",
312+
make_url("invalid+psycopg://postgres:airflow@postgres/airflow"),
313+
)
314+
def test_cli_shell_invalid_ppg3(self):
315+
pytest.importorskip("psycopg", reason="Test only runs when psycopg v3 is installed.")
316+
317+
with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg"):
318+
db_command.shell(self.parser.parse_args(["db", "shell"]))
319+
268320
@pytest.mark.parametrize(
269321
"args, match",
270322
[

devel-common/src/docs/utils/conf_constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,7 @@ def get_autodoc_mock_imports() -> list[str]:
248248
"pandas_gbq",
249249
"paramiko",
250250
"pinotdb",
251+
"psycopg",
251252
"psycopg2",
252253
"pydruid",
253254
"pyhive",

providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -810,6 +810,9 @@ def _run_command(self, cur, sql_statement, parameters):
810810
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
811811

812812
if parameters:
813+
# If we're using psycopg3, we might need to handle parameters differently
814+
if hasattr(cur, "__module__") and "psycopg" in cur.__module__ and isinstance(parameters, list):
815+
parameters = tuple(parameters)
813816
cur.execute(sql_statement, parameters)
814817
else:
815818
cur.execute(sql_statement)

providers/common/sql/tests/unit/common/sql/sensors/test_sql.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,8 @@ def test_sql_sensor_postgres(self):
8080
op2 = SqlSensor(
8181
task_id="sql_sensor_check_2",
8282
conn_id="postgres_default",
83-
sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
84-
parameters=["table_name"],
83+
sql="SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = %s",
84+
parameters=["information_schema"],
8585
dag=self.dag,
8686
)
8787
op2.execute({})

providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1175,9 +1175,9 @@ def cleanup_database_hook(self) -> None:
11751175
raise ValueError("The db_hook should be set")
11761176
if not isinstance(self.db_hook, PostgresHook):
11771177
raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}")
1178-
conn = getattr(self.db_hook, "conn")
1179-
if conn and conn.notices:
1180-
for output in self.db_hook.conn.notices:
1178+
conn = getattr(self.db_hook, "conn", None)
1179+
if conn and hasattr(conn, "notices") and conn.notices:
1180+
for output in conn.notices:
11811181
self.log.info(output)
11821182

11831183
def reserve_free_tcp_port(self) -> None:

providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py

Lines changed: 42 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from slugify import slugify
3232

3333
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
34-
from airflow.providers.postgres.hooks.postgres import PostgresHook
34+
from airflow.providers.postgres.hooks.postgres import USE_PSYCOPG3, PostgresHook
3535

3636
if TYPE_CHECKING:
3737
from airflow.providers.openlineage.extractors import OperatorLineage
@@ -52,9 +52,20 @@ def __init__(self, cursor):
5252
self.initialized = False
5353

5454
def __iter__(self):
55+
"""Make the cursor iterable."""
5556
return self
5657

5758
def __next__(self):
59+
"""Fetch next row from the cursor."""
60+
if USE_PSYCOPG3:
61+
if self.rows:
62+
return self.rows.pop()
63+
self.initialized = True
64+
row = self.cursor.fetchone()
65+
if row is None:
66+
raise StopIteration
67+
return row
68+
# psycopg2
5869
if self.rows:
5970
return self.rows.pop()
6071
self.initialized = True
@@ -141,13 +152,29 @@ def db_hook(self) -> PostgresHook:
141152
return PostgresHook(postgres_conn_id=self.postgres_conn_id)
142153

143154
def query(self):
144-
"""Query Postgres and returns a cursor to the results."""
155+
"""Execute the query and return a cursor."""
145156
conn = self.db_hook.get_conn()
146-
cursor = conn.cursor(name=self._unique_name())
147-
cursor.execute(self.sql, self.parameters)
148-
if self.use_server_side_cursor:
149-
cursor.itersize = self.cursor_itersize
150-
return _PostgresServerSideCursorDecorator(cursor)
157+
158+
if USE_PSYCOPG3:
159+
from psycopg.types.json import register_default_adapters
160+
161+
# Register JSON handlers for this connection if not already done
162+
register_default_adapters(conn)
163+
164+
if self.use_server_side_cursor:
165+
cursor_name = f"airflow_{self.task_id.replace('-', '_')}_{uuid.uuid4().hex}"[:63]
166+
cursor = conn.cursor(name=cursor_name)
167+
cursor.itersize = self.cursor_itersize
168+
cursor.execute(self.sql, self.parameters)
169+
return _PostgresServerSideCursorDecorator(cursor)
170+
cursor = conn.cursor()
171+
cursor.execute(self.sql, self.parameters)
172+
else:
173+
cursor = conn.cursor(name=self._unique_name())
174+
cursor.execute(self.sql, self.parameters)
175+
if self.use_server_side_cursor:
176+
cursor.itersize = self.cursor_itersize
177+
return _PostgresServerSideCursorDecorator(cursor)
151178
return cursor
152179

153180
def field_to_bigquery(self, field) -> dict[str, str]:
@@ -182,8 +209,14 @@ def convert_type(self, value, schema_type, stringify_dict=True):
182209
hours=formatted_time.tm_hour, minutes=formatted_time.tm_min, seconds=formatted_time.tm_sec
183210
)
184211
return str(time_delta)
185-
if stringify_dict and isinstance(value, dict):
186-
return json.dumps(value)
212+
if stringify_dict:
213+
if USE_PSYCOPG3:
214+
from psycopg.types.json import Json
215+
216+
if isinstance(value, (dict, Json)):
217+
return json.dumps(value)
218+
elif isinstance(value, dict):
219+
return json.dumps(value)
187220
if isinstance(value, Decimal):
188221
return float(value)
189222
return value

providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -383,6 +383,17 @@ def call_get_conn():
383383
"postgresql+psycopg2://login:password@host:1234/schema",
384384
id="sqlalchemy-scheme-with-driver",
385385
),
386+
pytest.param(
387+
{
388+
"conn_params": {
389+
"extra": json.dumps(
390+
{"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg"}
391+
)
392+
}
393+
},
394+
"postgresql+psycopg://login:password@host:1234/schema",
395+
id="sqlalchemy-scheme-with-driver-ppg3",
396+
),
386397
pytest.param(
387398
{
388399
"login": "user@domain",

0 commit comments

Comments (0)