
Commit 5a0019c

Oleg Kachur (olegkachur-e) authored and committed

Add import data verification for Google Cloud VertexAI and Translate datasets

For:
- Google Cloud VertexAI datasets.
- Google Cloud Translation native model datasets.

1 parent 97c0b40 · commit 5a0019c

6 files changed · 110 additions & 9 deletions
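The practical effect: both import operators can now fail fast when an import adds no data, and they return the item delta for downstream checks. A minimal usage sketch for the new flag (task id, dataset, project, location, and GCS paths are placeholders, not values from this commit):

# Hypothetical DAG snippet; ids, location, and paths are placeholders.
from airflow.providers.google.cloud.operators.translate import TranslateImportDataOperator

import_data = TranslateImportDataOperator(
    task_id="import_translation_data",
    dataset_id="my-dataset-id",
    project_id="my-gcp-project",
    location="us-central1",
    input_config={  # assumed to mirror the hook's import_dataset_data config
        "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "gs://my-bucket/data.tsv"}}]
    },
    raise_for_empty_result=True,  # new in this commit: fail if the dataset did not grow
)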

File tree

providers/google/src/airflow/providers/google/cloud/hooks/translate.py
providers/google/src/airflow/providers/google/cloud/links/translate.py
providers/google/src/airflow/providers/google/cloud/operators/translate.py
providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
providers/google/tests/unit/google/cloud/operators/test_translate.py
providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py

providers/google/src/airflow/providers/google/cloud/hooks/translate.py

Lines changed: 1 addition & 1 deletion
@@ -429,7 +429,7 @@ def get_dataset(
         project_id: str,
         location: str,
         retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | _MethodDefault = DEFAULT,
+        timeout: float | None | _MethodDefault = DEFAULT,
         metadata: Sequence[tuple[str, str]] = (),
     ) -> automl_translation.Dataset:
         """

providers/google/src/airflow/providers/google/cloud/links/translate.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ class TranslationNativeDatasetLink(BaseGoogleLink):
     """

     name = "Translation Native Dataset"
-    key = "translation_naive_dataset"
+    key = "translation_native_dataset"
     format_str = TRANSLATION_NATIVE_DATASET_LINK


providers/google/src/airflow/providers/google/cloud/operators/translate.py

Lines changed: 31 additions & 1 deletion
@@ -37,6 +37,7 @@
     TranslationNativeDatasetLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import DatasetImportDataResultsCheckHelper
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

 if TYPE_CHECKING:

@@ -575,7 +576,7 @@ def execute(self, context: Context):
         return result_ids


-class TranslateImportDataOperator(GoogleCloudBaseOperator):
+class TranslateImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
     """
     Import data to the translation dataset.

@@ -602,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
     """

     template_fields: Sequence[str] = (

@@ -627,6 +629,7 @@ def __init__(
         retry: Retry | _MethodDefault = DEFAULT,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)

@@ -639,9 +642,21 @@ def __init__(
         self.retry = retry
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result

     def execute(self, context: Context):
         hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
         self.log.info("Importing data to dataset...")
         operation = hook.import_dataset_data(
             dataset_id=self.dataset_id,

@@ -660,7 +675,22 @@ def execute(self, context: Context):
             location=self.location,
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
         self.log.info("Importing data finished!")
+        return {"total_imported": int(result_dataset_size) - int(initial_dataset_size)}


 class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
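Since execute now returns {"total_imported": ...}, the delta is pushed to XCom automatically; a downstream task could consume it like this (a sketch, assuming the operator runs under the hypothetical task id import_translation_data):

# Hypothetical downstream TaskFlow check reading the operator's XCom return value.
from airflow.decorators import task

@task
def verify_import(ti=None):
    result = ti.xcom_pull(task_ids="import_translation_data")  # {"total_imported": <int>}
    print(f"Imported {result['total_imported']} new examples")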

providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py

Lines changed: 44 additions & 2 deletions
@@ -26,6 +26,7 @@
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig

+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
 from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator

@@ -335,7 +336,21 @@ def execute(self, context: Context):
         self.log.info("Export was done successfully")


-class ImportDataOperator(GoogleCloudBaseOperator):
+class DatasetImportDataResultsCheckHelper:
+    """Helper utils to verify import dataset data results."""
+
+    @staticmethod
+    def _get_number_of_ds_items(dataset, total_key_name):
+        number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
+        return number_of_items
+
+    @staticmethod
+    def _raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
+        if int(size_after_import) - int(initial_size) <= 0:
+            raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")
+
+
+class ImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
     """
     Imports data into a Dataset.

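_get_number_of_ds_items leans on the proto-plus convention that to_dict is exposed on the message class and takes the instance as an argument, which is why type(dataset).to_dict(dataset) serves both the Translate and Vertex AI dataset types, and why a missing or unset count falls back to 0. A self-contained illustration, using an invented FakeDataset stand-in rather than the real proto types:

# Stand-in mimicking a proto message's class-level to_dict(instance) interface.
class FakeDataset:
    def __init__(self, example_count=None):
        self.example_count = example_count

    @classmethod
    def to_dict(cls, instance):
        # Model a message whose count field may be absent: drop unset (None) fields.
        return {k: v for k, v in vars(instance).items() if v is not None}

def get_number_of_ds_items(dataset, total_key_name):
    # Mirrors DatasetImportDataResultsCheckHelper._get_number_of_ds_items.
    return type(dataset).to_dict(dataset).get(total_key_name, 0)

assert get_number_of_ds_items(FakeDataset(example_count=101), "example_count") == 101
assert get_number_of_ds_items(FakeDataset(), "example_count") == 0  # unset -> 0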
@@ -356,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
     """

     template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")

@@ -372,6 +388,7 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)

@@ -384,13 +401,24 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result

     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
         self.log.info("Importing data: %s", self.dataset_id)
         operation = hook.import_data(
             project_id=self.project_id,

@@ -402,7 +430,21 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                region=self.region,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="data_item_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
         self.log.info("Import was done successfully")
+        return {"total_data_items_imported": int(result_dataset_size) - int(initial_dataset_size)}


 class ListDatasetsOperator(GoogleCloudBaseOperator):
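Note that the guard fires on any non-positive delta, so a dataset that stayed flat (or somehow shrank) during the import counts as an empty result. Reproducing the check's logic in isolation:

# Standalone restatement of _raise_for_empty_import_result's guard condition.
from airflow.exceptions import AirflowException

def raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
    if int(size_after_import) - int(initial_size) <= 0:
        raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")

raise_for_empty_import_result("ds-1", 1, 101)   # passes: 100 new items
# raise_for_empty_import_result("ds-1", 5, 5)   # would raise: zero new items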

providers/google/tests/unit/google/cloud/operators/test_translate.py

Lines changed: 16 additions & 1 deletion
@@ -22,6 +22,7 @@
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud.translate_v3.types import (
     BatchTranslateDocumentResponse,
+    Dataset,
     TranslateDocumentResponse,
     automl_translation,
     translation_service,

@@ -331,6 +332,19 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "import data gcs path"}}]
         }
         mock_hook.return_value.import_dataset_data.return_value = mock.MagicMock()
+
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "example_count": None,
+            "source_language_code": "en",
+            "target_language_code": "es",
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "example_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "example_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)]
         op = TranslateImportDataOperator(
             task_id="task_id",
             dataset_id=DATASET_ID,

@@ -343,7 +357,7 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             retry=DEFAULT,
         )
         context = mock.MagicMock()
-        op.execute(context=context)
+        res = op.execute(context=context)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,

@@ -363,6 +377,7 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
             location=LOCATION,
             project_id=PROJECT_ID,
         )
+        assert res["total_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE


 class TestTranslateDeleteData:
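Both test files rely on unittest.mock's behavior for an iterable side_effect: each call to the mocked get_dataset returns the next element, so the first call hands back the pre-import dataset and the second the post-import one. In isolation:

# An iterable side_effect yields successive return values, one per call.
from unittest import mock

get_dataset = mock.Mock(side_effect=["before_import", "after_import"])
assert get_dataset() == "before_import"
assert get_dataset() == "after_import"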

providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py

Lines changed: 17 additions & 3 deletions
@@ -26,6 +26,7 @@

 from google.api_core.gapic_v1.method import DEFAULT
 from google.api_core.retry import Retry
+from google.cloud.aiplatform_v1.types.dataset import Dataset

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (

@@ -1362,9 +1363,8 @@ def test_execute(self, mock_hook, to_dict_mock):


 class TestVertexAIImportDataOperator:
-    @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
     @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    def test_execute(self, mock_hook):
         op = ImportDataOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,

@@ -1377,7 +1377,20 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={})
+        SAMPLE_DATASET = {
+            "name": "sample_translation_dataset",
+            "display_name": "VertexAI dataset",
+            "data_item_count": None,
+        }
+        INITIAL_DS_SIZE = 1
+        FINAL_DS_SIZE = 101
+        INITIAL_DS = {**SAMPLE_DATASET, "data_item_count": INITIAL_DS_SIZE}
+        FINAL_DS = {**SAMPLE_DATASET, "data_item_count": FINAL_DS_SIZE}
+
+        mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)]
+
+        res = op.execute(context={})
+
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.import_data.assert_called_once_with(
             region=GCP_LOCATION,

@@ -1388,6 +1401,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        assert res["total_data_items_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE


 class TestVertexAIListDatasetsOperator:
