3434from airflow .providers .databricks .hooks .databricks import (
3535 GET_RUN_ENDPOINT ,
3636 SUBMIT_RUN_ENDPOINT ,
37+ ClusterState ,
3738 DatabricksHook ,
3839 RunState ,
3940)
7879 "state" : {"life_cycle_state" : LIFE_CYCLE_STATE , "state_message" : STATE_MESSAGE },
7980}
8081GET_RUN_OUTPUT_RESPONSE = {"metadata" : {}, "error" : ERROR_MESSAGE , "notebook_output" : {}}
82+ CLUSTER_STATE = "TERMINATED"
83+ CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)."
84+ GET_CLUSTER_RESPONSE = {"state" : CLUSTER_STATE , "state_message" : CLUSTER_STATE_MESSAGE }
8185NOTEBOOK_PARAMS = {"dry-run" : "true" , "oldest-time-to-consider" : "1457570074236" }
8286JAR_PARAMS = ["param1" , "param2" ]
8387RESULT_STATE = ""
@@ -159,6 +163,13 @@ def repair_run_endpoint(host):
159163 return f"https://{ host } /api/2.1/jobs/runs/repair"
160164
161165
def get_cluster_endpoint(host):
    """
    Utility function to generate the get cluster endpoint given the host.

    :param host: host name of the workspace, without the ``https://`` scheme
        (the scheme is prepended here).
    :return: URL of the ``api/2.0/clusters/get`` endpoint on that host.
    """
    # The path must follow the host directly; any whitespace in between
    # would produce a URL that never matches the one the hook requests.
    return f"https://{host}/api/2.0/clusters/get"
171+
172+
162173def start_cluster_endpoint (host ):
163174 """
164175 Utility function to generate the get run endpoint given the host.
@@ -598,6 +609,26 @@ def test_repair_run(self, mock_requests):
598609 timeout = self .hook .timeout_seconds ,
599610 )
600611
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, mock_requests):
    """
    Verify that ``get_cluster_state`` turns the mocked payload into a ``ClusterState``
    and performs exactly one GET against the cluster endpoint.

    Response example from https://docs.databricks.com/api/workspace/clusters/get
    """
    # Arrange: the patched ``requests`` module answers the GET with the canned payload.
    mock_requests.codes.ok = 200
    fake_response = mock_requests.get.return_value
    fake_response.json.return_value = GET_CLUSTER_RESPONSE

    # Act
    actual_state = self.hook.get_cluster_state(CLUSTER_ID)

    # Assert: parsed state first, then the exact shape of the outgoing request.
    assert actual_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
    expected_kwargs = {
        "json": None,
        "params": {"cluster_id": CLUSTER_ID},
        "auth": HTTPBasicAuth(LOGIN, PASSWORD),
        "headers": self.hook.user_agent_header,
        "timeout": self.hook.timeout_seconds,
    }
    mock_requests.get.assert_called_once_with(get_cluster_endpoint(HOST), **expected_kwargs)
631+
601632 @mock .patch ("airflow.providers.databricks.hooks.databricks_base.requests" )
602633 def test_start_cluster (self , mock_requests ):
603634 mock_requests .codes .ok = 200
@@ -952,8 +983,8 @@ def test_is_terminal_false(self):
952983 assert not run_state .is_terminal
953984
def test_is_terminal_with_nonexistent_life_cycle_state(self):
    """An unrecognised life cycle state must surface as ``AirflowException``."""
    # Both the construction and the property access sit inside the context
    # manager, so the exception is caught regardless of which of the two raises it.
    with pytest.raises(AirflowException):
        assert RunState("blah", "", "").is_terminal
958989
959990 def test_is_successful (self ):
@@ -973,6 +1004,41 @@ def test_from_json(self):
9731004 assert expected == RunState .from_json (json .dumps (state ))
9741005
9751006
class TestClusterState:
    """Tests for ``ClusterState``: state predicates and JSON round-tripping."""

    def test_is_terminal_true(self):
        """Each of the listed states must report ``is_terminal``."""
        for name in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN"):
            assert ClusterState(name, "").is_terminal

    def test_is_terminal_false(self):
        """Each of the listed states must not report ``is_terminal``."""
        for name in ("PENDING", "RUNNING", "RESTARTING", "RESIZING"):
            assert not ClusterState(name, "").is_terminal

    def test_is_terminal_with_nonexistent_life_cycle_state(self):
        """An unrecognised state name must surface as ``AirflowException``."""
        # Construction and property access are both inside the context manager,
        # so the exception is caught whichever of the two raises it.
        with pytest.raises(AirflowException):
            assert ClusterState("blah", "").is_terminal

    def test_is_running(self):
        """Each of the listed states must report ``is_running``."""
        for name in ("RUNNING", "RESIZING"):
            assert ClusterState(name, "").is_running

    def test_to_json(self):
        """``to_json`` must equal ``json.dumps`` of the canned cluster response."""
        serialized = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE).to_json()
        assert json.dumps(GET_CLUSTER_RESPONSE) == serialized

    def test_from_json(self):
        """``from_json`` must rebuild an equal ``ClusterState`` from the dumped response."""
        payload = json.dumps(GET_CLUSTER_RESPONSE)
        expected = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
        assert expected == ClusterState.from_json(payload)
1040+
1041+
9761042def create_aad_token_for_resource (resource : str ) -> dict :
9771043 return {
9781044 "token_type" : "Bearer" ,
@@ -1284,6 +1350,23 @@ async def test_get_run_state(self, mock_get):
12841350 timeout = self .hook .timeout_seconds ,
12851351 )
12861352
@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_cluster_state(self, mock_get):
    """
    Verify that ``a_get_cluster_state`` turns the mocked payload into a ``ClusterState``
    and performs exactly one GET against the cluster endpoint.
    """
    # Arrange: the object yielded by ``async with session.get(...)`` returns the canned payload.
    fake_response = mock_get.return_value.__aenter__.return_value
    fake_response.json = AsyncMock(return_value=GET_CLUSTER_RESPONSE)

    # Act: the hook is entered as an async context manager around the call.
    async with self.hook:
        actual_state = await self.hook.a_get_cluster_state(CLUSTER_ID)

    # Assert: parsed state first, then the exact shape of the outgoing request.
    assert actual_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
    expected_kwargs = {
        "json": {"cluster_id": CLUSTER_ID},
        "auth": aiohttp.BasicAuth(LOGIN, PASSWORD),
        "headers": self.hook.user_agent_header,
        "timeout": self.hook.timeout_seconds,
    }
    mock_get.assert_called_once_with(get_cluster_endpoint(HOST), **expected_kwargs)
1369+
12871370
12881371class TestDatabricksHookAsyncAadToken :
12891372 """
0 commit comments