Skip to content

Commit 7910e95

Browse files
committed
Enable DatabricksJobRunLink for Databricks plugin, skip provide_session usage in Airflow 3
1 parent a8588fb commit 7910e95

4 files changed

Lines changed: 375 additions & 88 deletions

File tree

providers/databricks/src/airflow/providers/databricks/operators/databricks.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from airflow.providers.databricks.plugins.databricks_workflow import (
4343
WorkflowJobRepairSingleTaskLink,
4444
WorkflowJobRunLink,
45+
store_databricks_job_run_link,
4546
)
4647
from airflow.providers.databricks.triggers.databricks import (
4748
DatabricksExecutionTrigger,
@@ -1214,10 +1215,16 @@ def __init__(
12141215
super().__init__(**kwargs)
12151216

12161217
if self._databricks_workflow_task_group is not None:
1217-
self.operator_extra_links = (
1218-
WorkflowJobRunLink(),
1219-
WorkflowJobRepairSingleTaskLink(),
1220-
)
1218+
# Conditionally set operator_extra_links based on Airflow version. In Airflow 3, only show the job run link.
1219+
# In Airflow 2, show the job run link and the repair link.
1220+
# TODO: Once we expand the plugin functionality in Airflow 3.1, this can be re-evaluated on how to handle the repair link.
1221+
if AIRFLOW_V_3_0_PLUS:
1222+
self.operator_extra_links = (WorkflowJobRunLink(),)
1223+
else:
1224+
self.operator_extra_links = (
1225+
WorkflowJobRunLink(),
1226+
WorkflowJobRepairSingleTaskLink(),
1227+
)
12211228
else:
12221229
# Databricks does not support repair for non-workflow tasks, hence do not show the repair link.
12231230
self.operator_extra_links = (DatabricksJobRunLink(),)
@@ -1427,6 +1434,15 @@ def execute(self, context: Context) -> None:
14271434
)
14281435
self.databricks_run_id = workflow_run_metadata.run_id
14291436
self.databricks_conn_id = workflow_run_metadata.conn_id
1437+
1438+
# Store operator links in XCom for Airflow 3 compatibility
1439+
if AIRFLOW_V_3_0_PLUS:
1440+
# Store the job run link
1441+
store_databricks_job_run_link(
1442+
context=context,
1443+
metadata=workflow_run_metadata,
1444+
logger=self.log,
1445+
)
14301446
else:
14311447
self._launch_job(context=context)
14321448
if self.wait_for_termination:

providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
from airflow.providers.databricks.plugins.databricks_workflow import (
3232
WorkflowJobRepairAllFailedLink,
3333
WorkflowJobRunLink,
34+
store_databricks_job_run_link,
3435
)
36+
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
3537
from airflow.utils.task_group import TaskGroup
3638

3739
if TYPE_CHECKING:
@@ -92,9 +94,18 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
9294
populated after instantiation using the `add_task` method.
9395
"""
9496

95-
operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
9697
template_fields = ("notebook_params", "job_clusters")
9798
caller = "_CreateDatabricksWorkflowOperator"
99+
# Conditionally set operator_extra_links based on Airflow version
100+
if AIRFLOW_V_3_0_PLUS:
101+
# In Airflow 3, disable "Repair All Failed Tasks" since we can't pre-determine failed tasks
102+
operator_extra_links = (WorkflowJobRunLink(),)
103+
else:
104+
# In Airflow 2.x, keep both links
105+
operator_extra_links = ( # type: ignore[assignment]
106+
WorkflowJobRunLink(),
107+
WorkflowJobRepairAllFailedLink(),
108+
)
98109

99110
def __init__(
100111
self,
@@ -219,6 +230,15 @@ def execute(self, context: Context) -> Any:
219230
run_id,
220231
)
221232

233+
# Store operator links in XCom for Airflow 3 compatibility
234+
if AIRFLOW_V_3_0_PLUS:
235+
# Store the job run link
236+
store_databricks_job_run_link(
237+
context=context,
238+
metadata=self.workflow_run_metadata,
239+
logger=self.log,
240+
)
241+
222242
return {
223243
"conn_id": self.databricks_conn_id,
224244
"job_id": job_id,

providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py

Lines changed: 118 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
else:
4141
from airflow.www import auth # type: ignore
4242
from airflow.utils.log.logging_mixin import LoggingMixin
43-
from airflow.utils.session import NEW_SESSION, provide_session
4443
from airflow.utils.state import TaskInstanceState
4544
from airflow.utils.task_group import TaskGroup
4645

@@ -49,6 +48,7 @@
4948

5049
from airflow.models import BaseOperator
5150
from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator
51+
from airflow.utils.context import Context
5252

5353
if AIRFLOW_V_3_0_PLUS:
5454
from airflow.sdk import BaseOperatorLink
@@ -93,32 +93,56 @@ def get_databricks_task_ids(
9393
return task_ids
9494

9595

96-
@provide_session
97-
def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
98-
"""
99-
Retrieve the DagRun object associated with the specified DAG and run_id.
96+
# TODO: Need to re-think on how to support the currently unavailable repair functionality in Airflow 3. Probably a
97+
# good time to re-evaluate this would be once the plugin functionality is expanded in Airflow 3.1.
98+
if not AIRFLOW_V_3_0_PLUS:
99+
from airflow.utils.session import NEW_SESSION, provide_session
100100

101-
:param dag: The DAG object associated with the DagRun to retrieve.
102-
:param run_id: The run_id associated with the DagRun to retrieve.
103-
:param session: The SQLAlchemy session to use for the query. If None, uses the default session.
104-
:return: The DagRun object associated with the specified DAG and run_id.
105-
"""
106-
if not session:
107-
raise AirflowException("Session not provided.")
101+
@provide_session
102+
def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
103+
"""
104+
Retrieve the DagRun object associated with the specified DAG and run_id.
108105
109-
return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
106+
:param dag: The DAG object associated with the DagRun to retrieve.
107+
:param run_id: The run_id associated with the DagRun to retrieve.
108+
:param session: The SQLAlchemy session to use for the query. If None, uses the default session.
109+
:return: The DagRun object associated with the specified DAG and run_id.
110+
"""
111+
if not session:
112+
raise AirflowException("Session not provided.")
110113

114+
return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
111115

112-
@provide_session
113-
def _clear_task_instances(
114-
dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
115-
) -> None:
116-
dag_bag = DagBag(read_dags_from_db=True)
117-
dag = dag_bag.get_dag(dag_id)
118-
log.debug("task_ids %s to clear", str(task_ids))
119-
dr: DagRun = _get_dagrun(dag, run_id, session=session)
120-
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
121-
clear_task_instances(tis_to_clear, session)
116+
@provide_session
117+
def _clear_task_instances(
118+
dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
119+
) -> None:
120+
dag_bag = DagBag(read_dags_from_db=True)
121+
dag = dag_bag.get_dag(dag_id)
122+
log.debug("task_ids %s to clear", str(task_ids))
123+
dr: DagRun = _get_dagrun(dag, run_id, session=session)
124+
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
125+
clear_task_instances(tis_to_clear, session)
126+
127+
@provide_session
128+
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
129+
dag_id = operator.dag.dag_id
130+
if hasattr(DagRun, "execution_date"): # Airflow 2.x.
131+
dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg]
132+
else:
133+
dag_run = DagRun.find(dag_id, logical_date=dttm)[0]
134+
ti = (
135+
session.query(TaskInstance)
136+
.filter(
137+
TaskInstance.dag_id == dag_id,
138+
TaskInstance.run_id == dag_run.run_id,
139+
TaskInstance.task_id == operator.task_id,
140+
)
141+
.one_or_none()
142+
)
143+
if not ti:
144+
raise TaskInstanceNotFound("Task instance not found")
145+
return ti
122146

123147

124148
def _repair_task(
@@ -201,27 +225,6 @@ def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> Tas
201225
return current_task_key
202226

203227

204-
@provide_session
205-
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
206-
dag_id = operator.dag.dag_id
207-
if hasattr(DagRun, "execution_date"): # Airflow 2.x.
208-
dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg]
209-
else:
210-
dag_run = DagRun.find(dag_id, logical_date=dttm)[0]
211-
ti = (
212-
session.query(TaskInstance)
213-
.filter(
214-
TaskInstance.dag_id == dag_id,
215-
TaskInstance.run_id == dag_run.run_id,
216-
TaskInstance.task_id == operator.task_id,
217-
)
218-
.one_or_none()
219-
)
220-
if not ti:
221-
raise TaskInstanceNotFound("Task instance not found")
222-
return ti
223-
224-
225228
def get_xcom_result(
226229
ti_key: TaskInstanceKey,
227230
key: str,
@@ -240,13 +243,41 @@ class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin):
240243

241244
name = "See Databricks Job Run"
242245

246+
@property
247+
def xcom_key(self) -> str:
248+
"""XCom key where the link is stored during task execution."""
249+
return "databricks_job_run_link"
250+
243251
def get_link(
244252
self,
245253
operator: BaseOperator,
246254
dttm=None,
247255
*,
248256
ti_key: TaskInstanceKey | None = None,
249257
) -> str:
258+
if AIRFLOW_V_3_0_PLUS:
259+
# Use public XCom API to get the pre-computed link
260+
try:
261+
link = XCom.get_value(
262+
ti_key=ti_key,
263+
key=self.xcom_key,
264+
)
265+
return link if link else ""
266+
except Exception as e:
267+
self.log.warning("Failed to retrieve Databricks job run link from XCom: %s", e)
268+
return ""
269+
else:
270+
# Airflow 2.x - keep original implementation
271+
return self._get_link_legacy(operator, dttm, ti_key=ti_key)
272+
273+
def _get_link_legacy(
274+
self,
275+
operator: BaseOperator,
276+
dttm=None,
277+
*,
278+
ti_key: TaskInstanceKey | None = None,
279+
) -> str:
280+
"""Legacy implementation for Airflow 2.x."""
250281
if not ti_key:
251282
ti = get_task_instance(operator, dttm)
252283
ti_key = ti.key
@@ -269,6 +300,30 @@ def get_link(
269300
return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"
270301

271302

303+
def store_databricks_job_run_link(
304+
context: Context,
305+
metadata: Any,
306+
logger: logging.Logger,
307+
) -> None:
308+
"""
309+
Store the Databricks job run link in XCom during task execution.
310+
311+
This should be called by Databricks operators during their execution.
312+
"""
313+
if not AIRFLOW_V_3_0_PLUS:
314+
return # Only needed for Airflow 3
315+
316+
try:
317+
hook = DatabricksHook(metadata.conn_id)
318+
link = f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"
319+
320+
# Store the link in XCom for the UI to retrieve as extra link
321+
context["ti"].xcom_push(key="databricks_job_run_link", value=link)
322+
logger.info("Stored Databricks job run link in XCom: %s", link)
323+
except Exception as e:
324+
logger.warning("Failed to store Databricks job run link: %s", e)
325+
326+
272327
class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin):
273328
"""Constructs a link to send a request to repair all failed tasks in the Databricks workflow."""
274329

@@ -455,13 +510,6 @@ def _get_return_url(dag_id: str, run_id: str) -> str:
455510
return url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id)
456511

457512

458-
repair_databricks_view = RepairDatabricksTasks()
459-
460-
repair_databricks_package = {
461-
"view": repair_databricks_view,
462-
}
463-
464-
465513
class DatabricksWorkflowPlugin(AirflowPlugin):
466514
"""
467515
Databricks Workflows plugin for Airflow.
@@ -472,9 +520,22 @@ class DatabricksWorkflowPlugin(AirflowPlugin):
472520
"""
473521

474522
name = "databricks_workflow"
475-
operator_extra_links = [
476-
WorkflowJobRepairAllFailedLink(),
477-
WorkflowJobRepairSingleTaskLink(),
478-
WorkflowJobRunLink(),
479-
]
480-
appbuilder_views = [repair_databricks_package]
523+
524+
# Conditionally set operator_extra_links based on Airflow version
525+
if AIRFLOW_V_3_0_PLUS:
526+
# In Airflow 3, disable the links for repair functionality until it is figured out it can be supported
527+
operator_extra_links = [
528+
WorkflowJobRunLink(),
529+
]
530+
else:
531+
# In Airflow 2.x, keep all links including repair all failed tasks
532+
operator_extra_links = [
533+
WorkflowJobRepairAllFailedLink(),
534+
WorkflowJobRepairSingleTaskLink(),
535+
WorkflowJobRunLink(),
536+
]
537+
repair_databricks_view = RepairDatabricksTasks()
538+
repair_databricks_package = {
539+
"view": repair_databricks_view,
540+
}
541+
appbuilder_views = [repair_databricks_package]

0 commit comments

Comments (0)