4040else :
4141 from airflow .www import auth # type: ignore
4242from airflow .utils .log .logging_mixin import LoggingMixin
43- from airflow .utils .session import NEW_SESSION , provide_session
4443from airflow .utils .state import TaskInstanceState
4544from airflow .utils .task_group import TaskGroup
4645
4948
5049 from airflow .models import BaseOperator
5150 from airflow .providers .databricks .operators .databricks import DatabricksTaskBaseOperator
51+ from airflow .utils .context import Context
5252
5353if AIRFLOW_V_3_0_PLUS :
5454 from airflow .sdk import BaseOperatorLink
@@ -93,32 +93,56 @@ def get_databricks_task_ids(
9393 return task_ids
9494
9595
# TODO: Need to re-think on how to support the currently unavailable repair functionality in Airflow 3. Probably a
# good time to re-evaluate this would be once the plugin functionality is expanded in Airflow 3.1.
if not AIRFLOW_V_3_0_PLUS:
    from airflow.utils.session import NEW_SESSION, provide_session

    @provide_session
    def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
        """
        Retrieve the DagRun object associated with the specified DAG and run_id.

        :param dag: The DAG object associated with the DagRun to retrieve.
        :param run_id: The run_id associated with the DagRun to retrieve.
        :param session: The SQLAlchemy session to use for the query. If None, uses the default session.
        :return: The DagRun object associated with the specified DAG and run_id.
        """
        if not session:
            raise AirflowException("Session not provided.")

        return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()

    @provide_session
    def _clear_task_instances(
        dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
    ) -> None:
        """
        Clear the task instances of the given run whose ``databricks_task_key`` is in *task_ids*.

        :param dag_id: ID of the DAG whose run is being cleared.
        :param run_id: run_id of the DagRun containing the task instances.
        :param task_ids: Databricks task keys identifying which task instances to clear.
        :param log: Logger used for debug output.
        :param session: SQLAlchemy session injected by ``provide_session``.
        """
        dag_bag = DagBag(read_dags_from_db=True)
        dag = dag_bag.get_dag(dag_id)
        log.debug("task_ids %s to clear", str(task_ids))
        dr: DagRun = _get_dagrun(dag, run_id, session=session)
        # Match on databricks_task_key (not task_id): the Databricks run identifies
        # tasks by this key.
        tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
        clear_task_instances(tis_to_clear, session)

    @provide_session
    def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
        """
        Look up the TaskInstance for *operator* in the DagRun at the given date.

        :param operator: Operator whose task instance is wanted.
        :param dttm: Execution/logical date used to locate the DagRun.
        :param session: SQLAlchemy session injected by ``provide_session``.
        :raises TaskInstanceNotFound: if no matching task instance exists.
        """
        dag_id = operator.dag.dag_id
        if hasattr(DagRun, "execution_date"):  # Airflow 2.x.
            dag_run = DagRun.find(dag_id, execution_date=dttm)[0]  # type: ignore[call-arg]
        else:
            dag_run = DagRun.find(dag_id, logical_date=dttm)[0]
        ti = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == dag_id,
                TaskInstance.run_id == dag_run.run_id,
                TaskInstance.task_id == operator.task_id,
            )
            .one_or_none()
        )
        if not ti:
            raise TaskInstanceNotFound("Task instance not found")
        return ti
122146
123147
124148def _repair_task (
@@ -201,27 +225,6 @@ def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> Tas
201225 return current_task_key
202226
203227
204- @provide_session
205- def get_task_instance (operator : BaseOperator , dttm , session : Session = NEW_SESSION ) -> TaskInstance :
206- dag_id = operator .dag .dag_id
207- if hasattr (DagRun , "execution_date" ): # Airflow 2.x.
208- dag_run = DagRun .find (dag_id , execution_date = dttm )[0 ] # type: ignore[call-arg]
209- else :
210- dag_run = DagRun .find (dag_id , logical_date = dttm )[0 ]
211- ti = (
212- session .query (TaskInstance )
213- .filter (
214- TaskInstance .dag_id == dag_id ,
215- TaskInstance .run_id == dag_run .run_id ,
216- TaskInstance .task_id == operator .task_id ,
217- )
218- .one_or_none ()
219- )
220- if not ti :
221- raise TaskInstanceNotFound ("Task instance not found" )
222- return ti
223-
224-
225228def get_xcom_result (
226229 ti_key : TaskInstanceKey ,
227230 key : str ,
@@ -240,13 +243,41 @@ class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin):
240243
241244 name = "See Databricks Job Run"
242245
246+ @property
247+ def xcom_key (self ) -> str :
248+ """XCom key where the link is stored during task execution."""
249+ return "databricks_job_run_link"
250+
243251 def get_link (
244252 self ,
245253 operator : BaseOperator ,
246254 dttm = None ,
247255 * ,
248256 ti_key : TaskInstanceKey | None = None ,
249257 ) -> str :
258+ if AIRFLOW_V_3_0_PLUS :
259+ # Use public XCom API to get the pre-computed link
260+ try :
261+ link = XCom .get_value (
262+ ti_key = ti_key ,
263+ key = self .xcom_key ,
264+ )
265+ return link if link else ""
266+ except Exception as e :
267+ self .log .warning ("Failed to retrieve Databricks job run link from XCom: %s" , e )
268+ return ""
269+ else :
270+ # Airflow 2.x - keep original implementation
271+ return self ._get_link_legacy (operator , dttm , ti_key = ti_key )
272+
273+ def _get_link_legacy (
274+ self ,
275+ operator : BaseOperator ,
276+ dttm = None ,
277+ * ,
278+ ti_key : TaskInstanceKey | None = None ,
279+ ) -> str :
280+ """Legacy implementation for Airflow 2.x."""
250281 if not ti_key :
251282 ti = get_task_instance (operator , dttm )
252283 ti_key = ti .key
@@ -269,6 +300,30 @@ def get_link(
269300 return f"https://{ hook .host } /#job/{ metadata .job_id } /run/{ metadata .run_id } "
270301
271302
def store_databricks_job_run_link(
    context: Context,
    metadata: Any,
    logger: logging.Logger,
) -> None:
    """
    Store the Databricks job run link in XCom during task execution.

    This should be called by Databricks operators during their execution.

    :param context: Task execution context; ``context["ti"]`` is used to push the XCom.
    :param metadata: Object carrying ``conn_id``, ``job_id`` and ``run_id`` for the run.
    :param logger: Logger for success/failure messages.
    """
    # Only needed for Airflow 3; Airflow 2.x computes the link in get_link().
    if not AIRFLOW_V_3_0_PLUS:
        return

    try:
        hook = DatabricksHook(metadata.conn_id)
        link = f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"
        # Store the link in XCom for the UI to retrieve as extra link
        context["ti"].xcom_push(key="databricks_job_run_link", value=link)
        logger.info("Stored Databricks job run link in XCom: %s", link)
    except Exception as e:
        # Best-effort: a missing link must not fail the task.
        logger.warning("Failed to store Databricks job run link: %s", e)
326+
272327class WorkflowJobRepairAllFailedLink (BaseOperatorLink , LoggingMixin ):
273328 """Constructs a link to send a request to repair all failed tasks in the Databricks workflow."""
274329
@@ -455,13 +510,6 @@ def _get_return_url(dag_id: str, run_id: str) -> str:
455510 return url_for ("Airflow.grid" , dag_id = dag_id , dag_run_id = run_id )
456511
457512
458- repair_databricks_view = RepairDatabricksTasks ()
459-
460- repair_databricks_package = {
461- "view" : repair_databricks_view ,
462- }
463-
464-
class DatabricksWorkflowPlugin(AirflowPlugin):
    """
    Databricks Workflows plugin for Airflow.

    Registers the operator extra links for Databricks workflow tasks and, on
    Airflow 2.x, the appbuilder view backing the repair endpoints.
    """

    # NOTE(review): part of the original class docstring is not visible in this
    # view; the summary above is reconstructed from the class body.

    name = "databricks_workflow"

    # Conditionally set operator_extra_links based on Airflow version
    if AIRFLOW_V_3_0_PLUS:
        # In Airflow 3, disable the links for repair functionality until it is figured out it can be supported
        operator_extra_links = [
            WorkflowJobRunLink(),
        ]
    else:
        # In Airflow 2.x, keep all links including repair all failed tasks
        operator_extra_links = [
            WorkflowJobRepairAllFailedLink(),
            WorkflowJobRepairSingleTaskLink(),
            WorkflowJobRunLink(),
        ]
        repair_databricks_view = RepairDatabricksTasks()
        repair_databricks_package = {
            "view": repair_databricks_view,
        }
        appbuilder_views = [repair_databricks_package]