Skip to content

Commit 6adae0b

Browse files
authored
openlineage: add support for hook lineage for Object Store (#40829)
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
1 parent b713b30 commit 6adae0b

7 files changed

Lines changed: 236 additions & 21 deletions

File tree

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,14 +692,14 @@ repos:
692692
^airflow/providers/.*\.py$
693693
exclude: ^.*/.*_vendor/
694694
- id: check-get-lineage-collector-providers
695-
language: pygrep
695+
language: python
696696
name: Check providers import hook lineage code from compat
697697
description: Make sure you import from airflow.provider.common.compat.lineage.hook instead of
698698
airflow.lineage.hook.
699-
entry: "airflow\\.lineage\\.hook"
700-
pass_filenames: true
699+
entry: ./scripts/ci/pre_commit/check_get_lineage_collector_providers.py
701700
files: ^airflow/providers/.*\.py$
702701
exclude: ^airflow/providers/common/compat/.*\.py$
702+
additional_dependencies: [ 'rich>=12.4.4' ]
703703
- id: check-decorated-operator-implements-custom-name
704704
name: Check @task decorator implements custom_operator_name
705705
language: python

airflow/io/path.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from airflow.io.store import attach
3131
from airflow.io.utils.stat import stat_result
32+
from airflow.lineage.hook import get_hook_lineage_collector
33+
from airflow.utils.log.logging_mixin import LoggingMixin
3234

3335
if typing.TYPE_CHECKING:
3436
from fsspec import AbstractFileSystem
@@ -39,6 +41,42 @@
3941
default = "file"
4042

4143

44+
class TrackingFileWrapper(LoggingMixin):
45+
"""Wrapper that tracks file operations to intercept lineage."""
46+
47+
def __init__(self, path: ObjectStoragePath, obj):
48+
super().__init__()
49+
self._path: ObjectStoragePath = path
50+
self._obj = obj
51+
52+
def __getattr__(self, name):
53+
attr = getattr(self._obj, name)
54+
if callable(attr):
55+
# If the attribute is a method, wrap it in another method to intercept the call
56+
def wrapper(*args, **kwargs):
57+
self.log.error("Calling method: %s", name)
58+
if name == "read":
59+
get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path))
60+
elif name == "write":
61+
get_hook_lineage_collector().add_output_dataset(context=self._path, uri=str(self._path))
62+
result = attr(*args, **kwargs)
63+
return result
64+
65+
return wrapper
66+
return attr
67+
68+
def __getitem__(self, key):
69+
# Intercept item access
70+
return self._obj[key]
71+
72+
def __enter__(self):
73+
self._obj.__enter__()
74+
return self
75+
76+
def __exit__(self, exc_type, exc_val, exc_tb):
77+
self._obj.__exit__(exc_type, exc_val, exc_tb)
78+
79+
4280
class ObjectStoragePath(CloudPath):
4381
"""A path-like object for object storage."""
4482

@@ -121,7 +159,7 @@ def namespace(self) -> str:
121159
def open(self, mode="r", **kwargs):
122160
"""Open the file pointed to by this path."""
123161
kwargs.setdefault("block_size", kwargs.pop("buffering", None))
124-
return self.fs.open(self.path, mode=mode, **kwargs)
162+
return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))
125163

126164
def stat(self) -> stat_result: # type: ignore[override]
127165
"""Call ``stat`` and return the result."""
@@ -276,6 +314,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
276314
if isinstance(dst, str):
277315
dst = ObjectStoragePath(dst)
278316

317+
if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
318+
# only emit this in "optimized" variants - else lineage will be captured by file writes/reads
319+
get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
320+
get_hook_lineage_collector().add_output_dataset(context=dst, uri=str(dst))
321+
279322
# same -> same
280323
if self.samestore(dst):
281324
self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
@@ -319,7 +362,6 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
319362
continue
320363

321364
src_obj._cp_file(dst)
322-
323365
return
324366

325367
# remote file -> remote dir
@@ -339,6 +381,8 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
339381
path = ObjectStoragePath(path)
340382

341383
if self.samestore(path):
384+
get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self))
385+
get_hook_lineage_collector().add_output_dataset(context=path, uri=str(path))
342386
return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)
343387

344388
# non-local copy

airflow/lineage/hook.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20-
from typing import Union
20+
from typing import TYPE_CHECKING, Union
2121

2222
import attr
2323

2424
from airflow.datasets import Dataset
25-
from airflow.hooks.base import BaseHook
26-
from airflow.io.store import ObjectStore
2725
from airflow.providers_manager import ProvidersManager
2826
from airflow.utils.log.logging_mixin import LoggingMixin
2927

30-
# Store context what sent lineage.
31-
LineageContext = Union[BaseHook, ObjectStore]
28+
if TYPE_CHECKING:
29+
from airflow.hooks.base import BaseHook
30+
from airflow.io.path import ObjectStoragePath
31+
32+
# Store context what sent lineage.
33+
LineageContext = Union[BaseHook, ObjectStoragePath]
3234

3335
_hook_lineage_collector: HookLineageCollector | None = None
3436

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
from __future__ import annotations
20+
21+
import ast
22+
import sys
23+
from pathlib import Path
24+
from typing import NamedTuple
25+
26+
sys.path.insert(0, str(Path(__file__).parent.resolve()))
27+
from common_precommit_utils import console, initialize_breeze_precommit
28+
29+
initialize_breeze_precommit(__name__, __file__)
30+
31+
32+
class ImportTuple(NamedTuple):
33+
module: list[str]
34+
name: list[str]
35+
alias: str
36+
37+
38+
def get_toplevel_imports(path: str):
39+
with open(path) as fh:
40+
root = ast.parse(fh.read(), path)
41+
42+
for node in ast.iter_child_nodes(root):
43+
if isinstance(node, ast.Import):
44+
module: list[str] = node.names[0].name.split(".") if node.names else []
45+
elif isinstance(node, ast.ImportFrom) and node.module:
46+
module = node.module.split(".")
47+
else:
48+
continue
49+
50+
for n in node.names: # type: ignore[attr-defined]
51+
yield ImportTuple(module=module, name=n.name.split("."), alias=n.asname)
52+
53+
54+
errors: list[str] = []
55+
56+
57+
def main() -> int:
58+
for path in sys.argv[1:]:
59+
import_count = 0
60+
local_error_count = 0
61+
for imp in get_toplevel_imports(path):
62+
import_count += 1
63+
if len(imp.module) > 2:
64+
if imp.module[:3] == ["airflow", "lineage", "hook"]:
65+
local_error_count += 1
66+
errors.append(f"{path}: ({'.'.join(imp.module)})")
67+
console.print(f"[blue]{path}:[/] Import count: {import_count}, error_count {local_error_count}")
68+
if errors:
69+
console.print(
70+
"[red]Some providers files import directly top level from `airflow.lineage.hook` and they are not allowed.[/]\n"
71+
"Only TYPE_CHECKING imports from `airflow.lineage.hook` is allowed in providers."
72+
)
73+
console.print("Error summary:")
74+
for error in errors:
75+
console.print(error)
76+
return 1
77+
else:
78+
console.print("[green]All good!")
79+
return 0
80+
81+
82+
if __name__ == "__main__":
83+
sys.exit(main())

scripts/ci/pre_commit/check_tests_in_right_folders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python
2+
#
23
# Licensed to the Apache Software Foundation (ASF) under one
34
# or more contributor license agreements. See the NOTICE file
45
# distributed with this work for additional information

tests/io/test_path.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,11 @@ def test_relative_to(self):
267267
with pytest.raises(ValueError):
268268
o1.relative_to(o3)
269269

270-
def test_move_local(self):
271-
_from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
272-
_to = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
270+
def test_move_local(self, hook_lineage_collector):
271+
_from_path = f"file:///tmp/{str(uuid.uuid4())}"
272+
_to_path = f"file:///tmp/{str(uuid.uuid4())}"
273+
_from = ObjectStoragePath(_from_path)
274+
_to = ObjectStoragePath(_to_path)
273275

274276
_from.touch()
275277
_from.move(_to)
@@ -278,13 +280,19 @@ def test_move_local(self):
278280

279281
_to.unlink()
280282

281-
def test_move_remote(self):
283+
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
284+
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
285+
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
286+
assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
287+
288+
def test_move_remote(self, hook_lineage_collector):
282289
attach("fakefs", fs=FakeRemoteFileSystem())
283290

284-
_from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
285-
print(_from)
286-
_to = ObjectStoragePath(f"fakefs:///tmp/{str(uuid.uuid4())}")
287-
print(_to)
291+
_from_path = f"file:///tmp/{str(uuid.uuid4())}"
292+
_to_path = f"fakefs:///tmp/{str(uuid.uuid4())}"
293+
294+
_from = ObjectStoragePath(_from_path)
295+
_to = ObjectStoragePath(_to_path)
288296

289297
_from.touch()
290298
_from.move(_to)
@@ -293,21 +301,28 @@ def test_move_remote(self):
293301

294302
_to.unlink()
295303

296-
def test_copy_remote_remote(self):
304+
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
305+
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
306+
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from))
307+
assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=str(_to))
308+
309+
def test_copy_remote_remote(self, hook_lineage_collector):
297310
attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
298311
attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))
299312

300313
dir_src = f"bucket1/{str(uuid.uuid4())}"
301314
dir_dst = f"bucket2/{str(uuid.uuid4())}"
302315
key = "foo/bar/baz.txt"
303316

304-
_from = ObjectStoragePath(f"ffs://{dir_src}")
317+
_from_path = f"ffs://{dir_src}"
318+
_from = ObjectStoragePath(_from_path)
305319
_from_file = _from / key
306320
_from_file.touch()
307321
assert _from.bucket == "bucket1"
308322
assert _from_file.exists()
309323

310-
_to = ObjectStoragePath(f"ffs2://{dir_dst}")
324+
_to_path = f"ffs2://{dir_dst}"
325+
_to = ObjectStoragePath(_to_path)
311326
_from.copy(_to)
312327

313328
assert _to.bucket == "bucket2"
@@ -319,6 +334,12 @@ def test_copy_remote_remote(self):
319334
_from.rmdir(recursive=True)
320335
_to.rmdir(recursive=True)
321336

337+
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
338+
assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from_file))
339+
340+
# Empty file - shutil.copyfileobj does nothing
341+
assert len(hook_lineage_collector.collected_datasets.outputs) == 0
342+
322343
def test_serde_objectstoragepath(self):
323344
path = "file:///bucket/key/part1/part2"
324345
o = ObjectStoragePath(path)

tests/io/test_wrapper.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
import uuid
20+
from unittest.mock import patch
21+
22+
from airflow.datasets import Dataset
23+
from airflow.io.path import ObjectStoragePath
24+
25+
26+
@patch("airflow.providers_manager.ProvidersManager")
27+
def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector):
28+
providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
29+
uri = f"file:///tmp/{str(uuid.uuid4())}"
30+
path = ObjectStoragePath(uri)
31+
file = path.open("w")
32+
file.write("aaa")
33+
file.close()
34+
35+
assert len(hook_lineage_collector.outputs) == 1
36+
assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
37+
38+
file = path.open("r")
39+
file.read()
40+
file.close()
41+
42+
path.unlink(missing_ok=True)
43+
44+
assert len(hook_lineage_collector.inputs) == 1
45+
assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
46+
47+
48+
@patch("airflow.providers_manager.ProvidersManager")
49+
def test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_collector):
50+
providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
51+
uri = f"file:///tmp/{str(uuid.uuid4())}"
52+
path = ObjectStoragePath(uri)
53+
with path.open("w") as file:
54+
file.write("asdf")
55+
56+
assert len(hook_lineage_collector.outputs) == 1
57+
assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
58+
59+
with path.open("r") as file:
60+
file.read()
61+
path.unlink(missing_ok=True)
62+
63+
assert len(hook_lineage_collector.inputs) == 1
64+
assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)

0 commit comments

Comments
 (0)