
Commit 89a8da6

Add optional 'location' parameter to the BigQueryInsertJobTrigger
Parent: cdd1a48
7 files changed: 103 additions and 39 deletions


airflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 53 additions & 10 deletions
@@ -20,6 +20,7 @@

 from __future__ import annotations

+import asyncio
 import json
 import logging
 import re
@@ -3242,16 +3243,58 @@ async def get_job_instance(
             session=cast(Session, session),
         )

-    async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]:
-        async with ClientSession() as s:
-            job_client = await self.get_job_instance(project_id, job_id, s)
-            job = await job_client.get_job()
-            status = job.get("status", {})
-            if status["state"] == "DONE":
-                if "errorResult" in status:
-                    return {"status": "error", "message": status["errorResult"]["message"]}
-                return {"status": "success", "message": "Job completed"}
-            return {"status": status["state"].lower(), "message": "Job running"}
+    async def _get_job(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
+        """
+        Get a BigQuery job by its ID, project ID and location.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it is not intended to be used elsewhere!
+        https://github.com/apache/airflow/issues/35833
+        https://github.com/talkiq/gcloud-aio/issues/584
+
+        This method exists because neither `google-cloud-bigquery` nor `gcloud-aio-bigquery`
+        provides asynchronous access to a BigQuery job with a location parameter. That is why this method
+        wraps the synchronous client call with the event loop's run_in_executor() method.
+
+        This workaround must be deleted along with the method _get_job_sync() and replaced by a more robust
+        and cleaner solution in one of two cases:
+        1. The `google-cloud-bigquery` library provides an async client whose get_job method supports the
+        optional `location` parameter.
+        2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        loop = asyncio.get_event_loop()
+        job = await loop.run_in_executor(None, self._get_job_sync, job_id, project_id, location)
+        return job
+
+    def _get_job_sync(self, job_id, project_id, location):
+        """
+        Get a BigQuery job by its ID, project ID and location synchronously.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it is not intended to be used elsewhere!
+        https://github.com/apache/airflow/issues/35833
+        https://github.com/talkiq/gcloud-aio/issues/584
+
+        This workaround must be deleted along with the method _get_job() and replaced by a more robust
+        and cleaner solution in one of two cases:
+        1. The `google-cloud-bigquery` library provides an async client whose get_job method supports the
+        optional `location` parameter.
+        2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        hook = BigQueryHook(**self._hook_kwargs)
+        return hook.get_job(job_id=job_id, project_id=project_id, location=location)
+
+    async def get_job_status(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> dict[str, str]:
+        job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
+        if job.state == "DONE":
+            if job.error_result:
+                return {"status": "error", "message": job.error_result["message"]}
+            return {"status": "success", "message": "Job completed"}
+        return {"status": str(job.state).lower(), "message": "Job running"}

     async def get_job_output(
         self,
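
The _get_job/_get_job_sync pair boils down to one pattern: push a blocking client call onto a worker thread so the triggerer's event loop keeps servicing other triggers. A minimal, self-contained sketch of that pattern follows; the function names and the returned dict are hypothetical stand-ins for the synchronous BigQuery client call, not the provider's code.

from __future__ import annotations

import asyncio


def fetch_job_blocking(job_id: str, location: str | None) -> dict:
    # Stand-in for a synchronous client call such as
    # google.cloud.bigquery.Client().get_job(job_id, location=location).
    return {"job_id": job_id, "location": location, "state": "DONE"}


async def fetch_job(job_id: str, location: str | None) -> dict:
    # run_in_executor(None, ...) runs the blocking function in the default
    # thread pool and returns an awaitable, so the event loop is never blocked.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, fetch_job_blocking, job_id, location)


if __name__ == "__main__":
    print(asyncio.run(fetch_job("job-123", "europe-west3")))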

airflow/providers/google/cloud/operators/bigquery.py

Lines changed: 5 additions & 0 deletions
@@ -313,6 +313,7 @@ def execute(self, context: Context):
                     conn_id=self.gcp_conn_id,
                     job_id=job.job_id,
                     project_id=hook.project_id,
+                    location=self.location or hook.location,
                     poll_interval=self.poll_interval,
                     impersonation_chain=self.impersonation_chain,
                 ),
@@ -438,6 +439,7 @@ def execute(self, context: Context) -> None: # type: ignore[override]
                     conn_id=self.gcp_conn_id,
                     job_id=job.job_id,
                     project_id=hook.project_id,
+                    location=self.location or hook.location,
                     sql=self.sql,
                     pass_value=self.pass_value,
                     tolerance=self.tol,
@@ -594,6 +596,7 @@ def execute(self, context: Context):
                     second_job_id=job_2.job_id,
                     project_id=hook.project_id,
                     table=self.table,
+                    location=self.location or hook.location,
                     metrics_thresholds=self.metrics_thresholds,
                     date_filter_column=self.date_filter_column,
                     days_back=self.days_back,
@@ -1068,6 +1071,7 @@ def execute(self, context: Context):
                     dataset_id=self.dataset_id,
                     table_id=self.table_id,
                     project_id=self.job_project_id or hook.project_id,
+                    location=self.location or hook.location,
                     poll_interval=self.poll_interval,
                     as_dict=self.as_dict,
                     impersonation_chain=self.impersonation_chain,
@@ -2876,6 +2880,7 @@ def execute(self, context: Any):
                     conn_id=self.gcp_conn_id,
                     job_id=self.job_id,
                     project_id=self.project_id,
+                    location=self.location or hook.location,
                     poll_interval=self.poll_interval,
                     impersonation_chain=self.impersonation_chain,
                 ),
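
From a DAG author's point of view, the practical effect of these five call sites is that a deferrable BigQuery operator run in a regional location can now be polled correctly by the trigger. A hypothetical usage snippet (task id and location value are illustrative, not part of this commit):

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

# The operator resolves the location as `self.location or hook.location` and
# forwards it to BigQueryInsertJobTrigger when deferring.
select_in_region = BigQueryInsertJobOperator(
    task_id="select_in_region",
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    location="europe-west3",
    deferrable=True,
)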

airflow/providers/google/cloud/transfers/bigquery_to_gcs.py

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ def execute(self, context: Context):
                     conn_id=self.gcp_conn_id,
                     job_id=self._job_id,
                     project_id=self.project_id or self.hook.project_id,
+                    location=self.location or self.hook.location,
                     impersonation_chain=self.impersonation_chain,
                 ),
                 method_name="execute_complete",

airflow/providers/google/cloud/transfers/gcs_to_bigquery.py

Lines changed: 1 addition & 0 deletions
@@ -435,6 +435,7 @@ def execute(self, context: Context):
                     conn_id=self.gcp_conn_id,
                     job_id=self.job_id,
                     project_id=self.project_id or self.hook.project_id,
+                    location=self.location or self.hook.location,
                     impersonation_chain=self.impersonation_chain,
                 ),
                 method_name="execute_complete",

airflow/providers/google/cloud/triggers/bigquery.py

Lines changed: 17 additions & 1 deletion
@@ -33,6 +33,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
     :param conn_id: Reference to google cloud connection id
     :param job_id: The ID of the job. It will be suffixed with hash of job configuration
     :param project_id: Google Cloud Project where the job is running
+    :param location: The dataset location.
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
     :param poll_interval: polling period in seconds to check for the status. (templated)
@@ -51,6 +52,7 @@ def __init__(
         conn_id: str,
         job_id: str | None,
         project_id: str | None,
+        location: str | None,
         dataset_id: str | None = None,
         table_id: str | None = None,
         poll_interval: float = 4.0,
@@ -63,6 +65,7 @@ def __init__(
         self._job_conn = None
         self.dataset_id = dataset_id
         self.project_id = project_id
+        self.location = location
         self.table_id = table_id
         self.poll_interval = poll_interval
         self.impersonation_chain = impersonation_chain
@@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -87,7 +91,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
         hook = self._get_async_hook()
         try:
             while True:
-                job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+                job_status = await hook.get_job_status(
+                    job_id=self.job_id, project_id=self.project_id, location=self.location
+                )
                 if job_status["status"] == "success":
                     yield TriggerEvent(
                         {
@@ -127,6 +133,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -201,6 +208,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -253,6 +261,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table: table name
     :param metrics_thresholds: dictionary of ratios indexed by metrics
+    :param location: The dataset location.
     :param date_filter_column: column name. (templated)
     :param days_back: number of days between ds and the ds we want to check against. (templated)
     :param ratio_formula: ration formula. (templated)
@@ -277,6 +286,7 @@ def __init__(
         project_id: str | None,
         table: str,
         metrics_thresholds: dict[str, int],
+        location: str | None = None,
         date_filter_column: str | None = "ds",
         days_back: SupportsAbs[int] = -7,
         ratio_formula: str = "max_over_min",
@@ -290,6 +300,7 @@ def __init__(
             conn_id=conn_id,
             job_id=first_job_id,
             project_id=project_id,
+            location=location,
             dataset_id=dataset_id,
             table_id=table_id,
             poll_interval=poll_interval,
@@ -317,6 +328,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "project_id": self.project_id,
                 "table": self.table,
                 "metrics_thresholds": self.metrics_thresholds,
+                "location": self.location,
                 "date_filter_column": self.date_filter_column,
                 "days_back": self.days_back,
                 "ratio_formula": self.ratio_formula,
@@ -414,6 +426,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
     :param tolerance: certain metrics for tolerance. (templated)
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
+    :param location: The dataset location.
     :param poll_interval: polling period in seconds to check for the status. (templated)
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -435,6 +448,7 @@ def __init__(
         tolerance: Any = None,
         dataset_id: str | None = None,
         table_id: str | None = None,
+        location: str | None = None,
         poll_interval: float = 4.0,
         impersonation_chain: str | Sequence[str] | None = None,
     ):
@@ -444,6 +458,7 @@ def __init__(
             project_id=project_id,
             dataset_id=dataset_id,
             table_id=table_id,
+            location=location,
             poll_interval=poll_interval,
             impersonation_chain=impersonation_chain,
         )
@@ -464,6 +479,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "sql": self.sql,
                 "table_id": self.table_id,
                 "tolerance": self.tolerance,
+                "location": self.location,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
             },
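
Triggers are serialized on the worker and rebuilt from that payload in the triggerer process, which is why every serialize() method above has to carry the new "location" key; without it, the rehydrated trigger would poll without a location again. A rough sketch of the round trip, using the constructor arguments shown in this diff (the connection, job and project values are illustrative):

from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

trigger = BigQueryInsertJobTrigger(
    conn_id="google_cloud_default",
    job_id="job-123",
    project_id="my-project",
    location="europe-west3",
)

# serialize() returns the classpath plus the kwargs the triggerer uses to
# rebuild the trigger; location only survives the hand-off if it is included.
classpath, kwargs = trigger.serialize()
assert kwargs["location"] == "europe-west3"
rehydrated = BigQueryInsertJobTrigger(**kwargs)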

tests/providers/google/cloud/hooks/test_bigquery.py

Lines changed: 7 additions & 12 deletions
@@ -2155,23 +2155,18 @@ async def test_get_job_instance(self, mock_session, mock_auth_default):
         assert isinstance(result, Job)

     @pytest.mark.parametrize(
-        "job_status, expected",
+        "job_state, error_result, expected",
         [
-            ({"status": {"state": "DONE"}}, {"status": "success", "message": "Job completed"}),
-            (
-                {"status": {"state": "DONE", "errorResult": {"message": "Timeout"}}},
-                {"status": "error", "message": "Timeout"},
-            ),
-            ({"status": {"state": "running"}}, {"status": "running", "message": "Job running"}),
+            ("DONE", None, {"status": "success", "message": "Job completed"}),
+            ("DONE", {"message": "Timeout"}, {"status": "error", "message": "Timeout"}),
+            ("RUNNING", None, {"status": "running", "message": "Job running"}),
         ],
     )
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_get_job_status(self, mock_job_instance, job_status, expected):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_get_job_status(self, mock_get_job, job_state, error_result, expected):
         hook = BigQueryAsyncHook()
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = job_status
+        mock_get_job.return_value = mock.MagicMock(state=job_state, error_result=error_result)
         resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
         assert resp == expected

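
The reworked test only exercises the status mapping through the mocked _get_job. A natural follow-up (hypothetical, not part of this commit) would assert that the location actually reaches the synchronous hook; sketched here with the test module's existing imports and constants (mock, pytest, BigQueryAsyncHook, JOB_ID, PROJECT_ID) assumed to be in scope:

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook")
    async def test_get_job_forwards_location(self, mock_hook_class):
        # _get_job wraps _get_job_sync in run_in_executor; the sync helper should
        # pass job_id, project_id and location through to BigQueryHook.get_job().
        hook = BigQueryAsyncHook()
        await hook._get_job(job_id=JOB_ID, project_id=PROJECT_ID, location="us-east1")
        mock_hook_class.return_value.get_job.assert_called_once_with(
            job_id=JOB_ID, project_id=PROJECT_ID, location="us-east1"
        )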
