Skip to content

Commit 946b539

Browse files
authored
Add DatabricksHook ClusterState (#34643)
* implement ClusterState and get_cluster_state() method
1 parent 6ba2c44 commit 946b539

2 files changed

Lines changed: 177 additions & 17 deletions

File tree

airflow/providers/databricks/hooks/databricks.py

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
3737
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
3838

39+
GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
3940
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
4041
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
4142
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
@@ -57,38 +58,39 @@
5758

5859
WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
5960

60-
RUN_LIFE_CYCLE_STATES = [
61-
"PENDING",
62-
"RUNNING",
63-
"TERMINATING",
64-
"TERMINATED",
65-
"SKIPPED",
66-
"INTERNAL_ERROR",
67-
"QUEUED",
68-
]
69-
7061
SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
7162

7263

7364
class RunState:
7465
"""Utility class for the run state concept of Databricks runs."""
7566

67+
RUN_LIFE_CYCLE_STATES = [
68+
"PENDING",
69+
"RUNNING",
70+
"TERMINATING",
71+
"TERMINATED",
72+
"SKIPPED",
73+
"INTERNAL_ERROR",
74+
"QUEUED",
75+
]
76+
7677
def __init__(
7778
self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs
7879
) -> None:
80+
if life_cycle_state not in self.RUN_LIFE_CYCLE_STATES:
81+
raise AirflowException(
82+
f"Unexpected life cycle state: {life_cycle_state}: If the state has "
83+
"been introduced recently, please check the Databricks user "
84+
"guide for troubleshooting information"
85+
)
86+
7987
self.life_cycle_state = life_cycle_state
8088
self.result_state = result_state
8189
self.state_message = state_message
8290

8391
@property
8492
def is_terminal(self) -> bool:
8593
"""True if the current state is a terminal state."""
86-
if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
87-
raise AirflowException(
88-
f"Unexpected life cycle state: {self.life_cycle_state}: If the state has "
89-
"been introduced recently, please check the Databricks user "
90-
"guide for troubleshooting information"
91-
)
9294
return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")
9395

9496
@property
@@ -116,6 +118,55 @@ def from_json(cls, data: str) -> RunState:
116118
return RunState(**json.loads(data))
117119

118120

121+
class ClusterState:
    """
    Utility class for the cluster state concept of Databricks cluster.

    An instance holds the ``state`` string and the accompanying
    ``state_message``. The state is validated against
    ``CLUSTER_LIFE_CYCLE_STATES`` at construction time.
    """

    CLUSTER_LIFE_CYCLE_STATES = [
        "PENDING",
        "RUNNING",
        "RESTARTING",
        "RESIZING",
        "TERMINATING",
        "TERMINATED",
        "ERROR",
        "UNKNOWN",
    ]

    def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
        """
        Create a cluster state, rejecting unknown life cycle states.

        :param state: life cycle state; must be one of ``CLUSTER_LIFE_CYCLE_STATES``
            (the default empty string is therefore rejected)
        :param state_message: message that accompanies the state
        :param args: ignored; accepted so extra positional values do not fail
        :param kwargs: ignored; accepted so extra keys (e.g. from ``from_json``) do not fail
        :raises AirflowException: if ``state`` is not a known life cycle state
        """
        if state not in self.CLUSTER_LIFE_CYCLE_STATES:
            raise AirflowException(
                f"Unexpected cluster life cycle state: {state}: If the state has "
                "been introduced recently, please check the Databricks user "
                "guide for troubleshooting information"
            )

        self.state = state
        self.state_message = state_message

    @property
    def is_terminal(self) -> bool:
        """True if the current state is a terminal state."""
        return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")

    @property
    def is_running(self) -> bool:
        """True if the current state is running."""
        return self.state in ("RUNNING", "RESIZING")

    def __eq__(self, other: object) -> bool:
        """
        Compare two cluster states by ``state`` and ``state_message``.

        Returns ``NotImplemented`` for objects that are not ``ClusterState``
        instances, so comparisons such as ``cluster_state == None`` evaluate to
        ``False`` instead of raising ``AttributeError``.
        """
        if not isinstance(other, ClusterState):
            return NotImplemented
        return self.state == other.state and self.state_message == other.state_message

    def __repr__(self) -> str:
        """Return the string form of the instance attribute dictionary."""
        return str(self.__dict__)

    def to_json(self) -> str:
        """Serialize the instance attributes to a JSON string."""
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, data: str) -> ClusterState:
        """
        Build a cluster state from a JSON object string.

        :param data: JSON object whose keys are passed to the constructor as keyword arguments
        :return: an instance of the class this method is called on
        """
        # Use ``cls`` rather than the hard-coded class name so that subclasses
        # deserialize to their own type.
        return cls(**json.loads(data))
168+
169+
119170
class DatabricksHook(BaseDatabricksHook):
120171
"""
121172
Interact with Databricks.
@@ -474,6 +525,32 @@ def repair_run(self, json: dict) -> None:
474525
"""
475526
self._do_api_call(REPAIR_RUN_ENDPOINT, json)
476527

528+
def get_cluster_state(self, cluster_id: str) -> ClusterState:
    """
    Retrieve the current state of a Databricks cluster.

    Calls the endpoint defined by ``GET_CLUSTER_ENDPOINT`` and wraps the
    ``state`` and ``state_message`` fields of the response in a ``ClusterState``.

    :param cluster_id: id of the cluster
    :return: state of the cluster
    """
    payload = {"cluster_id": cluster_id}
    cluster_info = self._do_api_call(GET_CLUSTER_ENDPOINT, payload)
    return ClusterState(cluster_info["state"], cluster_info["state_message"])
540+
541+
async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
    """
    Async version of `get_cluster_state`.

    Awaits the call to the endpoint defined by ``GET_CLUSTER_ENDPOINT`` and wraps
    the ``state`` and ``state_message`` fields of the response in a ``ClusterState``.

    :param cluster_id: id of the cluster
    :return: state of the cluster
    """
    payload = {"cluster_id": cluster_id}
    cluster_info = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, payload)
    return ClusterState(cluster_info["state"], cluster_info["state_message"])
553+
477554
def restart_cluster(self, json: dict) -> None:
478555
"""
479556
Restarts the cluster.

tests/providers/databricks/hooks/test_databricks.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from airflow.providers.databricks.hooks.databricks import (
3535
GET_RUN_ENDPOINT,
3636
SUBMIT_RUN_ENDPOINT,
37+
ClusterState,
3738
DatabricksHook,
3839
RunState,
3940
)
@@ -78,6 +79,9 @@
7879
"state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": STATE_MESSAGE},
7980
}
8081
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
82+
CLUSTER_STATE = "TERMINATED"
83+
CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)."
84+
GET_CLUSTER_RESPONSE = {"state": CLUSTER_STATE, "state_message": CLUSTER_STATE_MESSAGE}
8185
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
8286
JAR_PARAMS = ["param1", "param2"]
8387
RESULT_STATE = ""
@@ -159,6 +163,13 @@ def repair_run_endpoint(host):
159163
return f"https://{host}/api/2.1/jobs/runs/repair"
160164

161165

166+
def get_cluster_endpoint(host):
    """
    Utility function to generate the get cluster endpoint given the host.
    """
    api_path = "api/2.0/clusters/get"
    return f"https://{host}/{api_path}"
171+
172+
162173
def start_cluster_endpoint(host):
163174
"""
164175
Utility function to generate the get run endpoint given the host.
@@ -598,6 +609,26 @@ def test_repair_run(self, mock_requests):
598609
timeout=self.hook.timeout_seconds,
599610
)
600611

612+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, mock_requests):
    """
    Check that ``get_cluster_state`` returns the mocked state and issues one GET request.

    Response example from https://docs.databricks.com/api/workspace/clusters/get
    """
    mock_requests.codes.ok = 200
    mocked_get = mock_requests.get
    mocked_get.return_value.json.return_value = GET_CLUSTER_RESPONSE

    observed_state = self.hook.get_cluster_state(CLUSTER_ID)

    expected_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
    assert observed_state == expected_state
    # The cluster id is expected as a query parameter, with no JSON body.
    mocked_get.assert_called_once_with(
        get_cluster_endpoint(HOST),
        json=None,
        params={"cluster_id": CLUSTER_ID},
        auth=HTTPBasicAuth(LOGIN, PASSWORD),
        headers=self.hook.user_agent_header,
        timeout=self.hook.timeout_seconds,
    )
631+
601632
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
602633
def test_start_cluster(self, mock_requests):
603634
mock_requests.codes.ok = 200
@@ -952,8 +983,8 @@ def test_is_terminal_false(self):
952983
assert not run_state.is_terminal
953984

954985
def test_is_terminal_with_nonexistent_life_cycle_state(self):
955-
run_state = RunState("blah", "", "")
956986
with pytest.raises(AirflowException):
987+
run_state = RunState("blah", "", "")
957988
assert run_state.is_terminal
958989

959990
def test_is_successful(self):
@@ -973,6 +1004,41 @@ def test_from_json(self):
9731004
assert expected == RunState.from_json(json.dumps(state))
9741005

9751006

1007+
class TestClusterState:
    """Unit tests for the ``ClusterState`` utility class."""

    def test_is_terminal_true(self):
        """Every terminal state reports ``is_terminal`` as true."""
        terminal_states = ["TERMINATING", "TERMINATED", "ERROR", "UNKNOWN"]
        for state in terminal_states:
            cluster_state = ClusterState(state, "")
            assert cluster_state.is_terminal

    def test_is_terminal_false(self):
        """Every non-terminal state reports ``is_terminal`` as false."""
        non_terminal_states = ["PENDING", "RUNNING", "RESTARTING", "RESIZING"]
        for state in non_terminal_states:
            cluster_state = ClusterState(state, "")
            assert not cluster_state.is_terminal

    def test_is_terminal_with_nonexistent_life_cycle_state(self):
        """An unknown state is rejected when the object is constructed."""
        # The constructor itself raises, so nothing placed after it inside the
        # ``with`` block would ever run; only the constructor call belongs here.
        with pytest.raises(AirflowException):
            ClusterState("blah", "")

    def test_is_running(self):
        """Both running states report ``is_running`` as true."""
        running_states = ["RUNNING", "RESIZING"]
        for state in running_states:
            cluster_state = ClusterState(state, "")
            assert cluster_state.is_running

    def test_to_json(self):
        """``to_json`` serializes to the same JSON as the sample API response."""
        cluster_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
        expected = json.dumps(GET_CLUSTER_RESPONSE)
        assert expected == cluster_state.to_json()

    def test_from_json(self):
        """``from_json`` rebuilds an equal object from the sample API response."""
        state = GET_CLUSTER_RESPONSE
        expected = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
        assert expected == ClusterState.from_json(json.dumps(state))
1040+
1041+
9761042
def create_aad_token_for_resource(resource: str) -> dict:
9771043
return {
9781044
"token_type": "Bearer",
@@ -1284,6 +1350,23 @@ async def test_get_run_state(self, mock_get):
12841350
timeout=self.hook.timeout_seconds,
12851351
)
12861352

1353+
@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_cluster_state(self, mock_get):
    """Check that ``a_get_cluster_state`` returns the mocked state and issues one GET request."""
    mocked_response = mock_get.return_value.__aenter__.return_value
    mocked_response.json = AsyncMock(return_value=GET_CLUSTER_RESPONSE)

    async with self.hook:
        observed_state = await self.hook.a_get_cluster_state(CLUSTER_ID)

    expected_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
    assert observed_state == expected_state
    # The async call is expected to carry the cluster id in the JSON payload.
    mock_get.assert_called_once_with(
        get_cluster_endpoint(HOST),
        json={"cluster_id": CLUSTER_ID},
        auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
        headers=self.hook.user_agent_header,
        timeout=self.hook.timeout_seconds,
    )
1369+
12871370

12881371
class TestDatabricksHookAsyncAadToken:
12891372
"""

0 commit comments

Comments
 (0)