2929
3030from airflow .io .store import attach
3131from airflow .io .utils .stat import stat_result
32+ from airflow .lineage .hook import get_hook_lineage_collector
33+ from airflow .utils .log .logging_mixin import LoggingMixin
3234
3335if typing .TYPE_CHECKING :
3436 from fsspec import AbstractFileSystem
3941default = "file"
4042
4143
class TrackingFileWrapper(LoggingMixin):
    """
    Proxy around an opened file object that reports lineage on ``read``/``write`` calls.

    Attributes that are not found on the wrapper itself are looked up on the wrapped object.
    Callable attributes are handed back wrapped so each call can be observed: calling ``read``
    registers the path as an input dataset and calling ``write`` registers it as an output
    dataset with the hook lineage collector, before delegating to the real method.

    :param path: the path the file object was opened from; used as lineage context and URI.
    :param obj: the underlying object every operation is delegated to.
    """

    def __init__(self, path: ObjectStoragePath, obj):
        super().__init__()
        self._path: ObjectStoragePath = path
        self._obj = obj

    def __getattr__(self, name):
        # ``__getattr__`` only runs when normal lookup fails. If ``_obj`` itself is missing
        # (e.g. the instance was created without running ``__init__``), forwarding the lookup
        # would re-enter this method endlessly, so fail with a plain AttributeError instead.
        if name == "_obj":
            raise AttributeError(name)

        attr = getattr(self._obj, name)
        if not callable(attr):
            return attr

        # The attribute is callable: return a closure so the call itself can be intercepted.
        def wrapper(*args, **kwargs):
            # Debug level: this fires on every proxied call and is not an error condition.
            self.log.debug("Calling method: %s", name)
            if name == "read":
                get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path))
            elif name == "write":
                get_hook_lineage_collector().add_output_dataset(context=self._path, uri=str(self._path))
            return attr(*args, **kwargs)

        return wrapper

    def __getitem__(self, key):
        # Dunder lookups bypass ``__getattr__``, so item access is delegated explicitly.
        return self._obj[key]

    def __enter__(self):
        self._obj.__enter__()
        # Return the wrapper (not the raw object) so calls inside ``with`` stay tracked.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Propagate the wrapped object's result so its choice about suppressing the
        # in-flight exception is honoured.
        return self._obj.__exit__(exc_type, exc_val, exc_tb)
78+
79+
4280class ObjectStoragePath (CloudPath ):
4381 """A path-like object for object storage."""
4482
@@ -121,7 +159,7 @@ def namespace(self) -> str:
def open(self, mode="r", **kwargs):
    """Open the file this path points to and return it wrapped in a ``TrackingFileWrapper``."""
    # Translate a ``buffering`` keyword into ``block_size``; an explicitly passed
    # ``block_size`` wins, and ``buffering`` is always removed from the kwargs.
    buffering = kwargs.pop("buffering", None)
    kwargs.setdefault("block_size", buffering)
    handle = self.fs.open(self.path, mode=mode, **kwargs)
    return TrackingFileWrapper(self, handle)
125163
126164 def stat (self ) -> stat_result : # type: ignore[override]
127165 """Call ``stat`` and return the result."""
@@ -276,6 +314,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
276314 if isinstance (dst , str ):
277315 dst = ObjectStoragePath (dst )
278316
317+ if self .samestore (dst ) or self .protocol == "file" or dst .protocol == "file" :
318+ # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
319+ get_hook_lineage_collector ().add_input_dataset (context = self , uri = str (self ))
320+ get_hook_lineage_collector ().add_output_dataset (context = dst , uri = str (dst ))
321+
279322 # same -> same
280323 if self .samestore (dst ):
281324 self .fs .copy (self .path , dst .path , recursive = recursive , ** kwargs )
@@ -319,7 +362,6 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
319362 continue
320363
321364 src_obj ._cp_file (dst )
322-
323365 return
324366
325367 # remote file -> remote dir
@@ -339,6 +381,8 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
339381 path = ObjectStoragePath (path )
340382
341383 if self .samestore (path ):
384+ get_hook_lineage_collector ().add_input_dataset (context = self , uri = str (self ))
385+ get_hook_lineage_collector ().add_output_dataset (context = path , uri = str (path ))
342386 return self .fs .move (self .path , path .path , recursive = recursive , ** kwargs )
343387
344388 # non-local copy
0 commit comments