Skip to content

Commit 40a3c7e

Browse files
authored
feat: Add new query related methods to SnowflakeSqlApiHook (#52157)
1 parent 266c1ce commit 40a3c7e

2 files changed

Lines changed: 337 additions & 69 deletions

File tree

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import base64
20+
import time
2021
import uuid
2122
import warnings
2223
from datetime import timedelta
@@ -313,6 +314,80 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
313314
status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
314315
return self._process_response(status_code, resp)
315316

317+
def wait_for_query(
    self, query_id: str, raise_error: bool = False, poll_interval: int = 5, timeout: int = 60
) -> dict[str, str | list[str]]:
    """
    Wait for query to finish either successfully or with error.

    The query status is polled with ``get_sql_api_query_status`` until it is no longer ``running``.
    Completion is checked before the timeout, so a query that has already finished is never
    reported as timed out.

    :param query_id: statement handle id for the individual statement.
    :param raise_error: whether to raise an error if the query failed.
    :param poll_interval: time (in seconds) between checking the query status.
    :param timeout: max time (in seconds) to wait for the query to finish before raising a TimeoutError.

    :return: the last status response, i.e. the first one whose status is not ``running``.

    :raises RuntimeError: If the query status is 'error' and `raise_error` is True.
    :raises TimeoutError: If the query doesn't finish within the specified timeout.
    """
    # A monotonic clock keeps the elapsed time immune to system clock adjustments.
    start_time = time.monotonic()

    while True:
        response = self.get_sql_api_query_status(query_id=query_id)
        status = response["status"]
        self.log.debug("Query status `%s`", status)

        # Check for completion first: a finished query must win over an expired deadline.
        if status != "running":
            self.log.info("Query status `%s`", status)
            break

        remaining = timeout - (time.monotonic() - start_time)
        if remaining <= 0:
            raise TimeoutError(
                f"Query `{query_id}` did not finish within the timeout period of {timeout} seconds."
            )

        # Never sleep past the deadline; the last poll then happens right at the timeout.
        time.sleep(min(poll_interval, remaining))

    if status == "error" and raise_error:
        raise RuntimeError(response["message"])

    return response
352+
353+
def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
    """
    Based on the query id HTTP requests are made to snowflake SQL API and return result data.

    The first response carries the first partition of data together with the result set metadata.
    When the metadata lists more than one partition, the remaining ones are requested one by one
    and their rows are appended in partition order.

    :param query_id: statement handle id for the individual statement.

    :return: one dict per result row, mapping column name to value; an empty list when the
        response contains no data or no column metadata.

    :raises RuntimeError: If the query status is not 'success'.
    """
    self.log.info("Retrieving data for query id %s", query_id)
    header, params, url = self.get_request_url_header_params(query_id)
    status_code, response = self._make_api_call_with_retries("GET", url, header, params)

    if (query_status := self._process_response(status_code, response)["status"]) != "success":
        msg = f"Query must have status `success` to retrieve data; got `{query_status}`."
        raise RuntimeError(msg)

    # Below fields should always be present in response, but added some safety checks.
    # `or` fallbacks (rather than `.get` defaults) also cover keys that are present but null.
    # The rows are copied so that the response payload is not mutated when partitions are merged.
    data = list(response.get("data") or [])
    if not data:
        self.log.warning("No data found in the API response.")
        return []
    metadata = response.get("resultSetMetaData") or {}
    col_names = [column["name"] for column in metadata.get("rowType") or []]
    if not col_names:
        self.log.warning("No column metadata found in the API response.")
        return []

    num_partitions = len(metadata.get("partitionInfo") or [])
    if num_partitions > 1:
        self.log.debug("Result data is returned as multiple partitions. Will perform additional queries.")
        partition_url = url + "?partition="
        for partition_no in range(1, num_partitions):  # First partition was already returned
            self.log.debug("Querying for partition no. %s", partition_no)
            _, partition_response = self._make_api_call_with_retries(
                "GET", partition_url + str(partition_no), header, params
            )
            data.extend(partition_response.get("data") or [])

    return [dict(zip(col_names, row)) for row in data]  # Merged column names with data
390+
316391
async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
317392
"""
318393
Based on the query id async HTTP request is made to snowflake SQL API and return response.

0 commit comments

Comments
 (0)