|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | 19 | import base64 |
| 20 | +import time |
20 | 21 | import uuid |
21 | 22 | import warnings |
22 | 23 | from datetime import timedelta |
@@ -313,6 +314,80 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: |
313 | 314 | status_code, resp = self._make_api_call_with_retries("GET", url, header, params) |
314 | 315 | return self._process_response(status_code, resp) |
315 | 316 |
|
| 317 | + def wait_for_query( |
| 318 | + self, query_id: str, raise_error: bool = False, poll_interval: int = 5, timeout: int = 60 |
| 319 | + ) -> dict[str, str | list[str]]: |
| 320 | + """ |
| 321 | + Wait for query to finish either successfully or with error. |
| 322 | +
|
| 323 | + :param query_id: statement handle id for the individual statement. |
| 324 | + :param raise_error: whether to raise an error if the query failed. |
| 325 | + :param poll_interval: time (in seconds) between checking the query status. |
| 326 | + :param timeout: max time (in seconds) to wait for the query to finish before raising a TimeoutError. |
| 327 | +
|
| 328 | + :raises RuntimeError: If the query status is 'error' and `raise_error` is True. |
| 329 | + :raises TimeoutError: If the query doesn't finish within the specified timeout. |
| 330 | + """ |
| 331 | + start_time = time.time() |
| 332 | + |
| 333 | + while True: |
| 334 | + response = self.get_sql_api_query_status(query_id=query_id) |
| 335 | + self.log.debug("Query status `%s`", response["status"]) |
| 336 | + |
| 337 | + if time.time() - start_time > timeout: |
| 338 | + raise TimeoutError( |
| 339 | + f"Query `{query_id}` did not finish within the timeout period of {timeout} seconds." |
| 340 | + ) |
| 341 | + |
| 342 | + if response["status"] != "running": |
| 343 | + self.log.info("Query status `%s`", response["status"]) |
| 344 | + break |
| 345 | + |
| 346 | + time.sleep(poll_interval) |
| 347 | + |
| 348 | + if response["status"] == "error" and raise_error: |
| 349 | + raise RuntimeError(response["message"]) |
| 350 | + |
| 351 | + return response |
| 352 | + |
| 353 | + def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]: |
| 354 | + """ |
| 355 | + Based on the query id HTTP requests are made to snowflake SQL API and return result data. |
| 356 | +
|
| 357 | + :param query_id: statement handle id for the individual statement. |
| 358 | +
|
| 359 | + :raises RuntimeError: If the query status is not 'success'. |
| 360 | + """ |
| 361 | + self.log.info("Retrieving data for query id %s", query_id) |
| 362 | + header, params, url = self.get_request_url_header_params(query_id) |
| 363 | + status_code, response = self._make_api_call_with_retries("GET", url, header, params) |
| 364 | + |
| 365 | + if (query_status := self._process_response(status_code, response)["status"]) != "success": |
| 366 | + msg = f"Query must have status `success` to retrieve data; got `{query_status}`." |
| 367 | + raise RuntimeError(msg) |
| 368 | + |
| 369 | + # Below fields should always be present in response, but added some safety checks |
| 370 | + data = response.get("data", []) |
| 371 | + if not data: |
| 372 | + self.log.warning("No data found in the API response.") |
| 373 | + return [] |
| 374 | + metadata = response.get("resultSetMetaData", {}) |
| 375 | + col_names = [row["name"] for row in metadata.get("rowType", [])] |
| 376 | + if not col_names: |
| 377 | + self.log.warning("No column metadata found in the API response.") |
| 378 | + return [] |
| 379 | + |
| 380 | + num_partitions = len(metadata.get("partitionInfo", [])) |
| 381 | + if num_partitions > 1: |
| 382 | + self.log.debug("Result data is returned as multiple partitions. Will perform additional queries.") |
| 383 | + url += "?partition=" |
| 384 | + for partition_no in range(1, num_partitions): # First partition was already returned |
| 385 | + self.log.debug("Querying for partition no. %s", partition_no) |
| 386 | + _, response = self._make_api_call_with_retries("GET", url + str(partition_no), header, params) |
| 387 | + data.extend(response.get("data", [])) |
| 388 | + |
| 389 | + return [dict(zip(col_names, row)) for row in data] # Merged column names with data |
| 390 | + |
316 | 391 | async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]: |
317 | 392 | """ |
318 | 393 | Based on the query id async HTTP request is made to snowflake SQL API and return response. |
|
0 commit comments