Skip to content

Commit 574102f

Browse files
Authored by gaurav7261 (GauravM), with co-authorship from Taragolis (Andrey Anshin)
[FEAT] adds repair run functionality for databricks (#36601)
* [FEAT] adds repair run functionality for databricks
* [FIX] addded latest repair run and test cases
* [FIX] comma typo
* [FIX] check for DatabricksRunNowOperator instance before doing repair run
* [FIX] fixed static checks
* [FIX] fixed static checks
* Update airflow/providers/databricks/hooks/databricks.py
  Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
* [FIX] type annotations
* [FIX] change from log.warn to log.warning
* Update airflow/providers/databricks/operators/databricks.py
  Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
* [FIX] CI Static check

---------

Co-authored-by: GauravM <gaurav@ip-192-168-0-100.ap-south-1.compute.internal>
Co-authored-by: GauravM <gaurav@ip-192-168-0-101.ap-south-1.compute.internal>
Co-authored-by: GauravM <gaurav@ip-10-20-1-171.ap-south-1.compute.internal>
Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
1 parent 449c814 commit 574102f

3 files changed

Lines changed: 103 additions & 2 deletions

File tree

airflow/providers/databricks/hooks/databricks.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,13 +519,24 @@ def delete_run(self, run_id: int) -> None:
519519
json = {"run_id": run_id}
520520
self._do_api_call(DELETE_RUN_ENDPOINT, json)
521521

522-
def repair_run(self, json: dict) -> None:
522+
def repair_run(self, json: dict) -> int:
523523
"""
524524
Re-run one or more tasks.
525525
526526
:param json: repair a job run.
527527
"""
528-
self._do_api_call(REPAIR_RUN_ENDPOINT, json)
528+
response = self._do_api_call(REPAIR_RUN_ENDPOINT, json)
529+
return response["repair_id"]
530+
531+
def get_latest_repair_id(self, run_id: int) -> int | None:
532+
"""Get latest repair id if any exist for run_id else None."""
533+
json = {"run_id": run_id, "include_history": True}
534+
response = self._do_api_call(GET_RUN_ENDPOINT, json)
535+
repair_history = response["repair_history"]
536+
if len(repair_history) == 1:
537+
return None
538+
else:
539+
return repair_history[-1]["id"]
529540

530541
def get_cluster_state(self, cluster_id: str) -> ClusterState:
531542
"""

airflow/providers/databricks/operators/databricks.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
8888
f"{operator.task_id} failed with terminal state: {run_state} "
8989
f"and with the error {run_state.state_message}"
9090
)
91+
if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
92+
operator.repair_run = False
93+
log.warning(
94+
"%s but since repair run is set, repairing the run with all failed tasks",
95+
error_message,
96+
)
97+
98+
latest_repair_id = hook.get_latest_repair_id(operator.run_id)
99+
repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
100+
if latest_repair_id is not None:
101+
repair_json["latest_repair_id"] = latest_repair_id
102+
operator.json["latest_repair_id"] = hook.repair_run(operator, repair_json)
103+
_handle_databricks_operator_execution(operator, hook, log, context)
91104
raise AirflowException(error_message)
92105

93106
else:
@@ -623,6 +636,7 @@ class DatabricksRunNowOperator(BaseOperator):
623636
- ``jar_params``
624637
- ``spark_submit_params``
625638
- ``idempotency_token``
639+
- ``repair_run``
626640
627641
:param job_id: the job_id of the existing Databricks job.
628642
This field will be templated.
@@ -711,6 +725,7 @@ class DatabricksRunNowOperator(BaseOperator):
711725
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
712726
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
713727
:param deferrable: Run operator in the deferrable mode.
728+
:param repair_run: Repair the databricks run in case of failure, doesn't work in deferrable mode
714729
"""
715730

716731
# Used in airflow.models.BaseOperator
@@ -741,6 +756,7 @@ def __init__(
741756
do_xcom_push: bool = True,
742757
wait_for_termination: bool = True,
743758
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
759+
repair_run: bool = False,
744760
**kwargs,
745761
) -> None:
746762
"""Create a new ``DatabricksRunNowOperator``."""
@@ -753,6 +769,7 @@ def __init__(
753769
self.databricks_retry_args = databricks_retry_args
754770
self.wait_for_termination = wait_for_termination
755771
self.deferrable = deferrable
772+
self.repair_run = repair_run
756773

757774
if job_id is not None:
758775
self.json["job_id"] = job_id

tests/providers/databricks/hooks/test_databricks.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,79 @@ def test_repair_run(self, mock_requests):
683683
timeout=self.hook.timeout_seconds,
684684
)
685685

686+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
687+
def test_negative_get_latest_repair_id(self, mock_requests):
688+
mock_requests.codes.ok = 200
689+
mock_requests.get.return_value.json.return_value = {
690+
"job_id": JOB_ID,
691+
"run_id": RUN_ID,
692+
"state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
693+
"repair_history": [
694+
{
695+
"type": "ORIGINAL",
696+
"start_time": 1704528798059,
697+
"end_time": 1704529026679,
698+
"state": {
699+
"life_cycle_state": "RUNNING",
700+
"result_state": "RUNNING",
701+
"state_message": "dummy",
702+
"user_cancelled_or_timedout": "false",
703+
},
704+
"task_run_ids": [396529700633015, 1111270934390307],
705+
}
706+
],
707+
}
708+
latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
709+
710+
assert latest_repair_id is None
711+
712+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
713+
def test_positive_get_latest_repair_id(self, mock_requests):
714+
mock_requests.codes.ok = 200
715+
mock_requests.get.return_value.json.return_value = {
716+
"job_id": JOB_ID,
717+
"run_id": RUN_ID,
718+
"state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
719+
"repair_history": [
720+
{
721+
"type": "ORIGINAL",
722+
"start_time": 1704528798059,
723+
"end_time": 1704529026679,
724+
"state": {
725+
"life_cycle_state": "TERMINATED",
726+
"result_state": "CANCELED",
727+
"state_message": "dummy_original",
728+
"user_cancelled_or_timedout": "false",
729+
},
730+
"task_run_ids": [396529700633015, 1111270934390307],
731+
},
732+
{
733+
"type": "REPAIR",
734+
"start_time": 1704530276423,
735+
"end_time": 1704530363736,
736+
"state": {
737+
"life_cycle_state": "TERMINATED",
738+
"result_state": "CANCELED",
739+
"state_message": "dummy_repair_1",
740+
"user_cancelled_or_timedout": "true",
741+
},
742+
"id": 108607572123234,
743+
"task_run_ids": [396529700633015, 1111270934390307],
744+
},
745+
{
746+
"type": "REPAIR",
747+
"start_time": 1704531464690,
748+
"end_time": 1704531481590,
749+
"state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
750+
"id": 52532060060836,
751+
"task_run_ids": [396529700633015, 1111270934390307],
752+
},
753+
],
754+
}
755+
latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
756+
757+
assert latest_repair_id == 52532060060836
758+
686759
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
687760
def test_get_cluster_state(self, mock_requests):
688761
"""

0 commit comments

Comments (0)