From a4c6bdb916d918c377b83064fe0330931de423f3 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 22 Aug 2025 00:13:23 +0530 Subject: [PATCH 1/3] FEAT: Adding gettypeInfo --- mssql_python/cursor.py | 75 ++++++++++++ mssql_python/pybind/ddbc_bindings.cpp | 16 ++- mssql_python/pybind/ddbc_bindings.h | 2 + tests/test_004_cursor.py | 169 ++++++++++++++++++++++++++ 4 files changed, 260 insertions(+), 2 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 9183da98b..7a10be112 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -723,6 +723,81 @@ def execute( self._reset_inputsizes() # Reset input sizes after execution + def getTypeInfo(self, sqlType=None): + """ + Executes SQLGetTypeInfo and creates a result set with information about + the specified data type or all data types supported by the ODBC driver if not specified. + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # SQL_ALL_TYPES = 0 + sql_all_types = 0 + + try: + if sqlType is None: + # Get information about all data types + ret = ddbc_bindings.DDBCSQLGetTypeInfo(self.hstmt, sql_all_types) + else: + # Get information about specified data type + ret = ddbc_bindings.DDBCSQLGetTypeInfo(self.hstmt, sqlType) + + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + # Initialize the description based on result set metadata + column_metadata = [] + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + + # Initialize the description attribute with the column metadata + self._initialize_description(column_metadata) + + # Fetch all rows first + rows_data = [] + ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # If we have no rows, return an empty list + if not rows_data: + return [] + + # Create a custom column map for our Row objects + column_map = { + 'type_name': 0, + 'data_type': 1, + 'column_size': 2, + 'literal_prefix': 3, + 'literal_suffix': 4, + 'create_params': 5, + 'nullable': 6, + 'case_sensitive': 7, + 'searchable': 8, + 'unsigned_attribute': 9, + 'fixed_prec_scale': 10, + 'auto_unique_value': 11, + 'local_type_name': 12, + 'minimum_scale': 13, + 'maximum_scale': 14, + 'sql_data_type': 15, + 'sql_datetime_sub': 16, + 'num_prec_radix': 17, + 'interval_precision': 18 + } + + # Create result rows with the custom column map + result_rows = [] + for row_data in rows_data: + row = Row(self, self.description, row_data) + # Manually add the column map + row._column_map = column_map + result_rows.append(row) + + return result_rows + except Exception as e: + # Always reset the cursor on exception + self._reset_cursor() + raise e + @staticmethod def _select_best_sample_value(column): """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 6a8a0187a..2d176b7a8 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -123,6 +123,7 @@ SQLBindColFunc SQLBindCol_ptr = nullptr; SQLDescribeColFunc SQLDescribeCol_ptr = nullptr; SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; SQLColAttributeFunc SQLColAttribute_ptr = nullptr; +SQLGetTypeInfoFunc SQLGetTypeInfo_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -779,6 +780,7 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); + SQLGetTypeInfo_ptr = GetFunctionPointer(handle, "SQLGetTypeInfoW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -796,7 +798,8 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && + SQLGetTypeInfo_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -861,6 +864,14 @@ void SqlHandle::free() { } } +SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataType) { + if (!SQLGetTypeInfo_ptr) { + ThrowStdException("SQLGetTypeInfo function not loaded"); + } + + return SQLGetTypeInfo_ptr(StatementHandle->get(), DataType); +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -2579,7 +2590,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLSetStmtAttr", [](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) { return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0); }, "Set statement attributes"); - + m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, "Returns information about the data types that are supported by the data source", + py::arg("StatementHandle"), py::arg("DataType")); // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index d142276c6..98d9f9f5b 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -105,6 +105,7 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); +typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -148,6 +149,7 @@ extern SQLBindColFunc SQLBindCol_ptr; extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; +extern SQLGetTypeInfoFunc SQLGetTypeInfo_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index b9ac5a452..b18660659 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1739,6 +1739,175 @@ def test_cursor_setinputsizes_override_inference(db_connection): # Clean up cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") +def test_gettypeinfo_all_types(cursor): + """Test getTypeInfo with no arguments returns all data types""" + # Get all type information + type_info = cursor.getTypeInfo() + + # Verify we got results + assert type_info is not None, "getTypeInfo() should return results" + assert len(type_info) > 0, "getTypeInfo() should return at least one data type" + + # Verify common data types are present + type_names = [str(row.type_name).upper() for row in type_info] + assert any('VARCHAR' in name for name in type_names), "VARCHAR type should be in results" + assert any('INT' in name for name in type_names), "INTEGER type should be in results" + + # Verify first row has expected columns + first_row = type_info[0] + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + +def test_gettypeinfo_specific_type(cursor): + """Test getTypeInfo with specific type argument""" + from mssql_python.constants import ConstantsDDBC + + # Test with VARCHAR type (SQL_VARCHAR) + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + + # Verify we got results specific to VARCHAR + assert varchar_info is not None, "getTypeInfo(SQL_VARCHAR) should return results" + assert len(varchar_info) > 0, "getTypeInfo(SQL_VARCHAR) should return at least one row" + + # All rows should be related to VARCHAR type + for row in varchar_info: + assert 'varchar' in row.type_name or 'char' in row.type_name, \ + f"Expected VARCHAR type, got {row.type_name}" + assert row.data_type == ConstantsDDBC.SQL_VARCHAR.value, \ + f"Expected data_type={ConstantsDDBC.SQL_VARCHAR.value}, got {row.data_type}" + +def test_gettypeinfo_result_structure(cursor): + """Test the structure of getTypeInfo result rows""" + # Get info for a common type like INTEGER + from mssql_python.constants import ConstantsDDBC + + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value) + + # Make sure we have at least one result + assert len(int_info) > 0, "getTypeInfo for INTEGER should return results" + + # Check for all required columns in the result + first_row = int_info[0] + required_columns = [ + 'type_name', 'data_type', 'column_size', 'literal_prefix', + 'literal_suffix', 'create_params', 'nullable', 'case_sensitive', + 'searchable', 'unsigned_attribute', 'fixed_prec_scale', + 'auto_unique_value', 'local_type_name', 'minimum_scale', + 'maximum_scale', 'sql_data_type', 'sql_datetime_sub', + 'num_prec_radix', 'interval_precision' + ] + + for column in required_columns: + assert hasattr(first_row, column), f"Result missing required column: {column}" + +def test_gettypeinfo_numeric_type(cursor): + """Test getTypeInfo for numeric data types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about DECIMAL type + decimal_info = cursor.getTypeInfo(ConstantsDDBC.SQL_DECIMAL.value) + + # Verify decimal-specific attributes + assert len(decimal_info) > 0, "getTypeInfo for DECIMAL should return results" + + decimal_row = decimal_info[0] + # DECIMAL should have precision and scale parameters + assert decimal_row.create_params is not None, "DECIMAL should have create_params" + assert "PRECISION" in decimal_row.create_params.upper() or \ + "SCALE" in decimal_row.create_params.upper(), \ + "DECIMAL create_params should mention precision/scale" + + # Numeric types typically use base 10 for the num_prec_radix + assert decimal_row.num_prec_radix == 10, \ + f"Expected num_prec_radix=10 for DECIMAL, got {decimal_row.num_prec_radix}" + +def test_gettypeinfo_datetime_types(cursor): + """Test getTypeInfo for datetime types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about TIMESTAMP type instead of DATETIME + # SQL_TYPE_TIMESTAMP (93) is more commonly used for datetime in ODBC + datetime_info = cursor.getTypeInfo(ConstantsDDBC.SQL_TYPE_TIMESTAMP.value) + + # Verify we got datetime-related results + assert len(datetime_info) > 0, "getTypeInfo for TIMESTAMP should return results" + + # Check for datetime-specific attributes + first_row = datetime_info[0] + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + + # Datetime type names often contain 'date', 'time', or 'datetime' + type_name_lower = first_row.type_name.lower() + assert any(term in type_name_lower for term in ['date', 'time', 'timestamp', 'datetime']), \ + f"Expected datetime-related type name, got {first_row.type_name}" + +def test_gettypeinfo_multiple_calls(cursor): + """Test calling getTypeInfo multiple times in succession""" + from mssql_python.constants import ConstantsDDBC + + # First call - get all types + all_types = cursor.getTypeInfo() + assert len(all_types) > 0, "First call to getTypeInfo should return results" + + # Second call - get VARCHAR type + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + assert len(varchar_info) > 0, "Second call to getTypeInfo should return results" + + # Third call - get INTEGER type + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value) + assert len(int_info) > 0, "Third call to getTypeInfo should return results" + + # Verify the results are different between calls + assert len(all_types) > len(varchar_info), "All types should return more rows than specific type" + +def test_gettypeinfo_binary_types(cursor): + """Test getTypeInfo for binary data types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about BINARY or VARBINARY type + binary_info = cursor.getTypeInfo(ConstantsDDBC.SQL_BINARY.value) + + # Verify we got binary-related results + assert len(binary_info) > 0, "getTypeInfo for BINARY should return results" + + # Check for binary-specific attributes + for row in binary_info: + type_name_lower = row.type_name.lower() + # Include 'timestamp' as SQL Server reports it as a binary type + assert any(term in type_name_lower for term in ['binary', 'blob', 'image', 'timestamp']), \ + f"Expected binary-related type name, got {row.type_name}" + + # Binary types typically don't support case sensitivity + assert row.case_sensitive == 0, f"Binary types should not be case sensitive, got {row.case_sensitive}" + +def test_gettypeinfo_cached_results(cursor): + """Test that multiple identical calls to getTypeInfo are efficient""" + from mssql_python.constants import ConstantsDDBC + import time + + # First call - might be slower + start_time = time.time() + first_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + first_duration = time.time() - start_time + + # Give the system a moment + time.sleep(0.1) + + # Second call with same type - should be similar or faster + start_time = time.time() + second_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + second_duration = time.time() - start_time + + # Results should be consistent + assert len(first_result) == len(second_result), "Multiple calls should return same number of results" + + # Both calls should return the correct type info + for row in second_result: + assert row.data_type == ConstantsDDBC.SQL_VARCHAR.value, \ + f"Expected SQL_VARCHAR type, got {row.data_type}" + def test_close(db_connection): """Test closing the cursor""" try: From 11153ac09105268c19d5e42293acb2c3b4046526 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 12 Sep 2025 14:14:44 +0530 Subject: [PATCH 2/3] Resolving comments --- mssql_python/cursor.py | 90 ++++++++++++++++++++++------------------ tests/test_004_cursor.py | 22 +++++----- 2 files changed, 61 insertions(+), 51 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 7a10be112..41d633dce 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -639,6 +639,16 @@ def execute( use_prepare: Whether to use SQLPrepareW (default) or SQLExecDirectW. reset_cursor: Whether to reset the cursor before execution. """ + + # Restore original fetch methods if they exist + if hasattr(self, '_original_fetchone'): + self.fetchone = self._original_fetchone + self.fetchmany = self._original_fetchmany + self.fetchall = self._original_fetchall + del self._original_fetchone + del self._original_fetchmany + del self._original_fetchall + self._check_closed() # Check if the cursor is closed if reset_cursor: self._reset_cursor() @@ -753,46 +763,46 @@ def getTypeInfo(self, sqlType=None): # Initialize the description attribute with the column metadata self._initialize_description(column_metadata) - # Fetch all rows first - rows_data = [] - ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) - - # If we have no rows, return an empty list - if not rows_data: - return [] - - # Create a custom column map for our Row objects - column_map = { - 'type_name': 0, - 'data_type': 1, - 'column_size': 2, - 'literal_prefix': 3, - 'literal_suffix': 4, - 'create_params': 5, - 'nullable': 6, - 'case_sensitive': 7, - 'searchable': 8, - 'unsigned_attribute': 9, - 'fixed_prec_scale': 10, - 'auto_unique_value': 11, - 'local_type_name': 12, - 'minimum_scale': 13, - 'maximum_scale': 14, - 'sql_data_type': 15, - 'sql_datetime_sub': 16, - 'num_prec_radix': 17, - 'interval_precision': 18 - } - - # Create result rows with the custom column map - result_rows = [] - for row_data in rows_data: - row = Row(self, self.description, row_data) - # Manually add the column map - row._column_map = column_map - result_rows.append(row) - - return result_rows + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self except Exception as e: # Always reset the cursor on exception self._reset_cursor() diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index b18660659..619cc2304 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1742,7 +1742,7 @@ def test_cursor_setinputsizes_override_inference(db_connection): def test_gettypeinfo_all_types(cursor): """Test getTypeInfo with no arguments returns all data types""" # Get all type information - type_info = cursor.getTypeInfo() + type_info = cursor.getTypeInfo().fetchall() # Verify we got results assert type_info is not None, "getTypeInfo() should return results" @@ -1765,7 +1765,7 @@ def test_gettypeinfo_specific_type(cursor): from mssql_python.constants import ConstantsDDBC # Test with VARCHAR type (SQL_VARCHAR) - varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() # Verify we got results specific to VARCHAR assert varchar_info is not None, "getTypeInfo(SQL_VARCHAR) should return results" @@ -1783,7 +1783,7 @@ def test_gettypeinfo_result_structure(cursor): # Get info for a common type like INTEGER from mssql_python.constants import ConstantsDDBC - int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value) + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value).fetchall() # Make sure we have at least one result assert len(int_info) > 0, "getTypeInfo for INTEGER should return results" @@ -1807,7 +1807,7 @@ def test_gettypeinfo_numeric_type(cursor): from mssql_python.constants import ConstantsDDBC # Get information about DECIMAL type - decimal_info = cursor.getTypeInfo(ConstantsDDBC.SQL_DECIMAL.value) + decimal_info = cursor.getTypeInfo(ConstantsDDBC.SQL_DECIMAL.value).fetchall() # Verify decimal-specific attributes assert len(decimal_info) > 0, "getTypeInfo for DECIMAL should return results" @@ -1829,7 +1829,7 @@ def test_gettypeinfo_datetime_types(cursor): # Get information about TIMESTAMP type instead of DATETIME # SQL_TYPE_TIMESTAMP (93) is more commonly used for datetime in ODBC - datetime_info = cursor.getTypeInfo(ConstantsDDBC.SQL_TYPE_TIMESTAMP.value) + datetime_info = cursor.getTypeInfo(ConstantsDDBC.SQL_TYPE_TIMESTAMP.value).fetchall() # Verify we got datetime-related results assert len(datetime_info) > 0, "getTypeInfo for TIMESTAMP should return results" @@ -1848,15 +1848,15 @@ def test_gettypeinfo_multiple_calls(cursor): from mssql_python.constants import ConstantsDDBC # First call - get all types - all_types = cursor.getTypeInfo() + all_types = cursor.getTypeInfo().fetchall() assert len(all_types) > 0, "First call to getTypeInfo should return results" # Second call - get VARCHAR type - varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() assert len(varchar_info) > 0, "Second call to getTypeInfo should return results" # Third call - get INTEGER type - int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value) + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value).fetchall() assert len(int_info) > 0, "Third call to getTypeInfo should return results" # Verify the results are different between calls @@ -1867,7 +1867,7 @@ def test_gettypeinfo_binary_types(cursor): from mssql_python.constants import ConstantsDDBC # Get information about BINARY or VARBINARY type - binary_info = cursor.getTypeInfo(ConstantsDDBC.SQL_BINARY.value) + binary_info = cursor.getTypeInfo(ConstantsDDBC.SQL_BINARY.value).fetchall() # Verify we got binary-related results assert len(binary_info) > 0, "getTypeInfo for BINARY should return results" @@ -1889,7 +1889,7 @@ def test_gettypeinfo_cached_results(cursor): # First call - might be slower start_time = time.time() - first_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + first_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() first_duration = time.time() - start_time # Give the system a moment @@ -1897,7 +1897,7 @@ def test_gettypeinfo_cached_results(cursor): # Second call with same type - should be similar or faster start_time = time.time() - second_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value) + second_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() second_duration = time.time() - start_time # Results should be consistent From 52996ac85d5e7cf9b0cebd455ef9f06906644a5a Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar <61936179+jahnvi480@users.noreply.github.com> Date: Mon, 15 Sep 2025 10:34:23 +0530 Subject: [PATCH 3/3] FEAT: Adding procedures API (#194) ### Work Item / Issue Reference > [AB#34933](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34933) ------------------------------------------------------------------- ### Summary This pull request adds support for retrieving metadata about stored procedures in SQL Server via the Python driver, including both temporary and permanent procedures. It introduces a new `procedures()` method on the `cursor` object, updates the C++ pybind layer to expose the ODBC `SQLProcedures` API, and includes a comprehensive suite of tests to ensure correct behavior across various scenarios. **New feature: Procedures metadata retrieval** * Added `procedures()` method to the `cursor` class in `mssql_python/cursor.py` to fetch information about stored procedures, handling both temporary (using direct queries) and permanent procedures (using the ODBC API). The method supports filtering by procedure name, catalog, and schema, and returns detailed metadata for each procedure. **ODBC bindings and pybind integration** * Introduced new function pointer type `SQLProceduresFunc` and related extern variable to the ODBC bindings header and implementation files (`mssql_python/pybind/ddbc_bindings.h`, `mssql_python/pybind/ddbc_bindings.cpp`). * Loaded the `SQLProceduresW` function pointer during driver initialization and included it in the required function check. * Implemented a wrapper function `SQLProcedures_wrap` and exposed it to Python as `DDBCSQLProcedures` for cross-platform use. **Testing and validation** * Added a comprehensive set of tests to `tests/test_004_cursor.py` to verify the new `procedures()` functionality, including filtering by name, schema, and catalog, handling of input/output parameters, result set reporting, and correct behavior with non-existent procedures. --------- Co-authored-by: Jahnvi Thakkar --- mssql_python/constants.py | 9 + mssql_python/cursor.py | 854 +++++++++++- mssql_python/pybind/ddbc_bindings.cpp | 345 ++++- mssql_python/pybind/ddbc_bindings.h | 22 + tests/test_004_cursor.py | 1831 +++++++++++++++++++++++++ 5 files changed, 3059 insertions(+), 2 deletions(-) diff --git a/mssql_python/constants.py b/mssql_python/constants.py index a4e0c7072..e63fbd1b5 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -117,6 +117,15 @@ class ConstantsDDBC(Enum): SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 SQL_ATTR_QUERY_TIMEOUT = 0 + SQL_SCOPE_CURROW = 0 + SQL_BEST_ROWID = 1 + SQL_ROWVER = 2 + SQL_NO_NULLS = 0 + SQL_NULLABLE_UNKNOWN = 2 + SQL_INDEX_UNIQUE = 0 + SQL_INDEX_ALL = 1 + SQL_QUICK = 0 + SQL_ENSURE = 1 class AuthType(Enum): """Constants for authentication types""" diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 41d633dce..a3e6d0ccc 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -16,7 +16,7 @@ from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, ProgrammingError from mssql_python.row import Row from mssql_python import get_settings @@ -46,6 +46,16 @@ class Cursor: setoutputsize(size, column=None) -> None. """ + # TODO(jathakkar): Thread safety considerations + # The cursor class contains methods that are not thread-safe due to: + # 1. Methods that mutate cursor state (_reset_cursor, self.description, etc.) + # 2. Methods that call ODBC functions with shared handles (self.hstmt) + # + # These methods should be properly synchronized or redesigned when implementing + # async functionality to prevent race conditions and data corruption. + # Consider using locks, redesigning for immutability, or ensuring + # cursor objects are never shared across threads. + def __init__(self, connection, timeout: int = 0) -> None: """ Initialize the cursor with a database connection. @@ -807,6 +817,848 @@ def fetchall_with_mapping(): # Always reset the cursor on exception self._reset_cursor() raise e + + def procedures(self, procedure=None, catalog=None, schema=None): + """ + Executes SQLProcedures and creates a result set of information about procedures in the data source. + + Args: + procedure (str, optional): Procedure name pattern. Default is None (all procedures). + catalog (str, optional): Catalog name pattern. Default is None (current catalog). + schema (str, optional): Schema name pattern. Default is None (all schemas). + + Returns: + List of Row objects, each containing procedure information with these columns: + - procedure_cat (str): The catalog name + - procedure_schem (str): The schema name + - procedure_name (str): The procedure name + - num_input_params (int): Number of input parameters + - num_output_params (int): Number of output parameters + - num_result_sets (int): Number of result sets + - remarks (str): Comments about the procedure + - procedure_type (int): Type of procedure (1=procedure, 2=function) + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Call the SQLProcedures function + retcode = ddbc_bindings.DDBCSQLProcedures(self.hstmt, catalog, schema, procedure) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Create column metadata and initialize description + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description + column_types = [str, str, str, int, int, int, str, int] + self.description = [ + ("procedure_cat", column_types[0], None, 128, 128, 0, True), + ("procedure_schem", column_types[1], None, 128, 128, 0, True), + ("procedure_name", column_types[2], None, 128, 128, 0, False), + ("num_input_params", column_types[3], None, 10, 10, 0, True), + ("num_output_params", column_types[4], None, 10, 10, 0, True), + ("num_result_sets", column_types[5], None, 10, 10, 0, True), + ("remarks", column_types[6], None, 254, 254, 0, True), + ("procedure_type", column_types[7], None, 10, 10, 0, False) + ] + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def primaryKeys(self, table, catalog=None, schema=None): + """ + Creates a result set of column names that make up the primary key for a table + by executing the SQLPrimaryKeys function. + + Args: + table (str): The name of the table + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. + + Returns: + list: A list of rows with the following columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - column_name: Column name that is part of the primary key + - key_seq: Column sequence number in the primary key (starting with 1) + - pk_name: Primary key name + + Raises: + ProgrammingError: If the cursor is closed + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Call the SQLPrimaryKeys function + retcode = ddbc_bindings.DDBCSQLPrimaryKeys( + self.hstmt, + catalog, + schema, + table + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("column_name", column_types[3], None, 128, 128, 0, False), + ("key_seq", column_types[4], None, 10, 10, 0, False), + ("pk_name", column_types[5], None, 128, 128, 0, True) + ] + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def foreignKeys(self, table=None, catalog=None, schema=None, foreignTable=None, foreignCatalog=None, foreignSchema=None): + """ + Executes the SQLForeignKeys function and creates a result set of column names that are foreign keys. + + This function returns: + 1. Foreign keys in the specified table that reference primary keys in other tables, OR + 2. Foreign keys in other tables that reference the primary key in the specified table + + Args: + table (str, optional): The table containing the foreign key columns + catalog (str, optional): The catalog containing table + schema (str, optional): The schema containing table + foreignTable (str, optional): The table containing the primary key columns + foreignCatalog (str, optional): The catalog containing foreignTable + foreignSchema (str, optional): The schema containing foreignTable + + Returns: + List of Row objects, each containing foreign key information with these columns: + - pktable_cat (str): Primary key table catalog name + - pktable_schem (str): Primary key table schema name + - pktable_name (str): Primary key table name + - pkcolumn_name (str): Primary key column name + - fktable_cat (str): Foreign key table catalog name + - fktable_schem (str): Foreign key table schema name + - fktable_name (str): Foreign key table name + - fkcolumn_name (str): Foreign key column name + - key_seq (int): Sequence number of the column in the foreign key + - update_rule (int): Action for update (CASCADE, SET NULL, etc.) + - delete_rule (int): Action for delete (CASCADE, SET NULL, etc.) + - fk_name (str): Foreign key name + - pk_name (str): Primary key name + - deferrability (int): Indicates if constraint checking can be deferred + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Check if we have at least one table specified - mimic pyodbc behavior + if table is None and foreignTable is None: + raise ProgrammingError("Either table or foreignTable must be specified", "HY000") + + # Call the SQLForeignKeys function + retcode = ddbc_bindings.DDBCSQLForeignKeys( + self.hstmt, + foreignCatalog, foreignSchema, foreignTable, + catalog, schema, table + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, str, str, str, str, int, int, int, str, str, int] + self.description = [ + ("pktable_cat", column_types[0], None, 128, 128, 0, True), + ("pktable_schem", column_types[1], None, 128, 128, 0, True), + ("pktable_name", column_types[2], None, 128, 128, 0, False), + ("pkcolumn_name", column_types[3], None, 128, 128, 0, False), + ("fktable_cat", column_types[4], None, 128, 128, 0, True), + ("fktable_schem", column_types[5], None, 128, 128, 0, True), + ("fktable_name", column_types[6], None, 128, 128, 0, False), + ("fkcolumn_name", column_types[7], None, 128, 128, 0, False), + ("key_seq", column_types[8], None, 10, 10, 0, False), + ("update_rule", column_types[9], None, 10, 10, 0, False), + ("delete_rule", column_types[10], None, 10, 10, 0, False), + ("fk_name", column_types[11], None, 128, 128, 0, True), + ("pk_name", column_types[12], None, 128, 128, 0, True), + ("deferrability", column_types[13], None, 10, 10, 0, False) + ] + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def rowIdColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a result set of + columns that uniquely identify a row. + + Args: + table (str): The table name + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. + nullable (bool, optional): Whether to include nullable columns. Defaults to True. + + Returns: + list: A list of rows with the following columns: + - scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION + - column_name: Column name + - data_type: The ODBC SQL data type constant (e.g. SQL_CHAR) + - type_name: Type name + - column_size: Column size + - buffer_length: Buffer length + - decimal_digits: Decimal digits + - pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Convert None values to empty strings as required by ODBC API + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type to SQL_BEST_ROWID (1) + identifier_type = ddbc_sql_const.SQL_BEST_ROWID.value + + # Set scope to SQL_SCOPE_CURROW (0) - default scope + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + + # Set nullable flag + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, + identifier_type, + catalog, + schema, + table, + scope, + nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [int, str, int, str, int, int, int, int] + self.description = [ + ("scope", column_types[0], None, 10, 10, 0, False), + ("column_name", column_types[1], None, 128, 128, 0, False), + ("data_type", column_types[2], None, 10, 10, 0, False), + ("type_name", column_types[3], None, 128, 128, 0, False), + ("column_size", column_types[4], None, 10, 10, 0, False), + ("buffer_length", column_types[5], None, 10, 10, 0, False), + ("decimal_digits", column_types[6], None, 10, 10, 0, True), + ("pseudo_column", column_types[7], None, 10, 10, 0, False) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def rowVerColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_ROWVER which creates a result set of + columns that are automatically updated when any value in the row is updated. + + Args: + table (str): The table name + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. + nullable (bool, optional): Whether to include nullable columns. Defaults to True. + + Returns: + list: A list of rows with the following columns: + - scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION + - column_name: Column name + - data_type: The ODBC SQL data type constant (e.g. SQL_CHAR) + - type_name: Type name + - column_size: Column size + - buffer_length: Buffer length + - decimal_digits: Decimal digits + - pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type to SQL_ROWVER (2) + identifier_type = ddbc_sql_const.SQL_ROWVER.value + + # Set scope to SQL_SCOPE_CURROW (0) - default scope + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + + # Set nullable flag + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, + identifier_type, + catalog, + schema, + table, + scope, + nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [int, str, int, str, int, int, int, int] + self.description = [ + ("scope", column_types[0], None, 10, 10, 0, False), + ("column_name", column_types[1], None, 128, 128, 0, False), + ("data_type", column_types[2], None, 10, 10, 0, False), + ("type_name", column_types[3], None, 128, 128, 0, False), + ("column_size", column_types[4], None, 10, 10, 0, False), + ("buffer_length", column_types[5], None, 10, 10, 0, False), + ("decimal_digits", column_types[6], None, 10, 10, 0, True), + ("pseudo_column", column_types[7], None, 10, 10, 0, False) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def statistics(self, table: str, catalog: str = None, schema: str = None, unique: bool = False, quick: bool = True) -> 'Cursor': + """ + Creates a result set of statistics about a single table and the indexes associated + with the table by executing SQLStatistics. + + Args: + table (str): The name of the table. + catalog (str, optional): The catalog name. Defaults to None. + schema (str, optional): The schema name. Defaults to None. + unique (bool, optional): If True, only unique indexes are returned. + If False, all indexes are returned. Defaults to False. + quick (bool, optional): If True, CARDINALITY and PAGES are returned only + if readily available. Defaults to True. + + Returns: + cursor: The cursor itself, containing the result set. Use fetchone(), fetchmany(), + or fetchall() to retrieve the results. + + Example: + # Get statistics for the 'Customers' table + stats_cursor = cursor.statistics(table='Customers') + + # Fetch rows as needed + first_stat = stats_cursor.fetchone() + next_10_stats = stats_cursor.fetchmany(10) + all_remaining = stats_cursor.fetchall() + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Table name is required + if not table: + raise ProgrammingError("Table name is required", "HY000") + + # Set unique flag (SQL_INDEX_UNIQUE = 0, SQL_INDEX_ALL = 1) + unique_option = ddbc_sql_const.SQL_INDEX_UNIQUE.value if unique else ddbc_sql_const.SQL_INDEX_ALL.value + + # Set quick flag (SQL_QUICK = 0, SQL_ENSURE = 1) + reserved_option = ddbc_sql_const.SQL_QUICK.value if quick else ddbc_sql_const.SQL_ENSURE.value + + # Call the SQLStatistics function + retcode = ddbc_bindings.DDBCSQLStatistics( + self.hstmt, + catalog, + schema, + table, + unique_option, + reserved_option + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, bool, str, str, int, int, str, str, int, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("non_unique", column_types[3], None, 1, 1, 0, False), + ("index_qualifier", column_types[4], None, 128, 128, 0, True), + ("index_name", column_types[5], None, 128, 128, 0, True), + ("type", column_types[6], None, 10, 10, 0, False), + ("ordinal_position", column_types[7], None, 10, 10, 0, False), + ("column_name", column_types[8], None, 128, 128, 0, True), + ("asc_or_desc", column_types[9], None, 1, 1, 0, True), + ("cardinality", column_types[10], None, 20, 20, 0, True), + ("pages", column_types[11], None, 20, 20, 0, True), + ("filter_condition", column_types[12], None, 128, 128, 0, True) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + return self + + def columns(self, table=None, catalog=None, schema=None, column=None): + """ + Creates a result set of column information in the specified tables + using the SQLColumns function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None (current catalog). + schema (str, optional): The schema name pattern. Default is None (all schemas). + column (str, optional): The column name pattern. Default is None (all columns). + + Returns: + cursor: The cursor itself, containing the result set. Use fetchone(), fetchmany(), + or fetchall() to retrieve the results. + + Each row contains the following columns: + - table_cat (str): Catalog name + - table_schem (str): Schema name + - table_name (str): Table name + - column_name (str): Column name + - data_type (int): The ODBC SQL data type constant (e.g. SQL_CHAR) + - type_name (str): Data source dependent type name + - column_size (int): Column size + - buffer_length (int): Length of the column in bytes + - decimal_digits (int): Number of fractional digits + - num_prec_radix (int): Radix (typically 10 or 2) + - nullable (int): One of SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE_UNKNOWN + - remarks (str): Comments about the column + - column_def (str): Default value for the column + - sql_data_type (int): The SQL data type from java.sql.Types + - sql_datetime_sub (int): Subcode for datetime types + - char_octet_length (int): Maximum length in bytes for char types + - ordinal_position (int): Column position in the table (starting at 1) + - is_nullable (str): "YES", "NO", or "" (unknown) + + Warning: + Calling this method without any filters (all parameters as None) will enumerate + EVERY column in EVERY table in the database. This can be extremely expensive in + large databases, potentially causing high memory usage, slow execution times, + and in extreme cases, timeout errors. Always use filters (catalog, schema, table, + or column) whenever possible to limit the result set. + + Example: + # Get all columns in table 'Customers' + columns = cursor.columns(table='Customers') + + # Get all columns in table 'Customers' in schema 'dbo' + columns = cursor.columns(table='Customers', schema='dbo') + + # Get column named 'CustomerID' in any table + columns = cursor.columns(column='CustomerID') + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Call the SQLColumns function + retcode = ddbc_bindings.DDBCSQLColumns( + self.hstmt, + catalog, + schema, + table, + column + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, int, str, int, int, int, int, int, str, str, int, int, int, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("column_name", column_types[3], None, 128, 128, 0, False), + ("data_type", column_types[4], None, 10, 10, 0, False), + ("type_name", column_types[5], None, 128, 128, 0, False), + ("column_size", column_types[6], None, 10, 10, 0, True), + ("buffer_length", column_types[7], None, 10, 10, 0, True), + ("decimal_digits", column_types[8], None, 10, 10, 0, True), + ("num_prec_radix", column_types[9], None, 10, 10, 0, True), + ("nullable", column_types[10], None, 10, 10, 0, False), + ("remarks", column_types[11], None, 254, 254, 0, True), + ("column_def", column_types[12], None, 254, 254, 0, True), + ("sql_data_type", column_types[13], None, 10, 10, 0, False), + ("sql_datetime_sub", column_types[14], None, 10, 10, 0, True), + ("char_octet_length", column_types[15], None, 10, 10, 0, True), + ("ordinal_position", column_types[16], None, 10, 10, 0, False), + ("is_nullable", column_types[17], None, 254, 254, 0, True) + ] + + # Store the column mappings for this specific columns() call + column_names = [desc[0] for desc in self.description] + + # Create a specialized column map for this result set + columns_map = {} + for i, name in enumerate(column_names): + columns_map[name] = i + columns_map[name.lower()] = i + + # Define wrapped fetch methods that preserve existing column mapping + # but add our specialized mapping just for column results + def fetchone_with_columns_mapping(): + row = self._original_fetchone() + if row is not None: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return row + + def fetchmany_with_columns_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return rows + + def fetchall_with_columns_mapping(): + rows = self._original_fetchall() + for row in rows: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return rows + + # Save original fetch methods + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Override fetch methods with our wrapped versions + self.fetchone = fetchone_with_columns_mapping + self.fetchmany = fetchmany_with_columns_mapping + self.fetchall = fetchall_with_columns_mapping + + return self @staticmethod def _select_best_sample_value(column): diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 2d176b7a8..f3aed22ad 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -124,6 +124,12 @@ SQLDescribeColFunc SQLDescribeCol_ptr = nullptr; SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; SQLColAttributeFunc SQLColAttribute_ptr = nullptr; SQLGetTypeInfoFunc SQLGetTypeInfo_ptr = nullptr; +SQLProceduresFunc SQLProcedures_ptr = nullptr; +SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr; +SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; +SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr; +SQLStatisticsFunc SQLStatistics_ptr = nullptr; +SQLColumnsFunc SQLColumns_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -781,6 +787,12 @@ DriverHandle LoadDriverOrThrowException() { SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); SQLGetTypeInfo_ptr = GetFunctionPointer(handle, "SQLGetTypeInfoW"); + SQLProcedures_ptr = GetFunctionPointer(handle, "SQLProceduresW"); + SQLForeignKeys_ptr = GetFunctionPointer(handle, "SQLForeignKeysW"); + SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW"); + SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); + SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); + SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -799,7 +811,9 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr && - SQLGetTypeInfo_ptr; + SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && + SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr && + SQLColumns_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -872,6 +886,236 @@ SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataT return SQLGetTypeInfo_ptr(StatementHandle->get(), DataType); } +SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& procedureObj) { + if (!SQLProcedures_ptr) { + ThrowStdException("SQLProcedures function not loaded"); + } + + std::wstring catalog = py::isinstance(catalogObj) ? L"" : catalogObj.cast(); + std::wstring schema = py::isinstance(schemaObj) ? L"" : schemaObj.cast(); + std::wstring procedure = py::isinstance(procedureObj) ? L"" : procedureObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector procedureBuf = WStringToSQLWCHAR(procedure); + + return SQLProcedures_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + procedure.empty() ? nullptr : procedureBuf.data(), + procedure.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLProcedures_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + procedure.empty() ? nullptr : (SQLWCHAR*)procedure.c_str(), + procedure.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& pkCatalogObj, + const py::object& pkSchemaObj, + const py::object& pkTableObj, + const py::object& fkCatalogObj, + const py::object& fkSchemaObj, + const py::object& fkTableObj) { + if (!SQLForeignKeys_ptr) { + ThrowStdException("SQLForeignKeys function not loaded"); + } + + std::wstring pkCatalog = py::isinstance(pkCatalogObj) ? L"" : pkCatalogObj.cast(); + std::wstring pkSchema = py::isinstance(pkSchemaObj) ? L"" : pkSchemaObj.cast(); + std::wstring pkTable = py::isinstance(pkTableObj) ? L"" : pkTableObj.cast(); + std::wstring fkCatalog = py::isinstance(fkCatalogObj) ? L"" : fkCatalogObj.cast(); + std::wstring fkSchema = py::isinstance(fkSchemaObj) ? L"" : fkSchemaObj.cast(); + std::wstring fkTable = py::isinstance(fkTableObj) ? L"" : fkTableObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector pkCatalogBuf = WStringToSQLWCHAR(pkCatalog); + std::vector pkSchemaBuf = WStringToSQLWCHAR(pkSchema); + std::vector pkTableBuf = WStringToSQLWCHAR(pkTable); + std::vector fkCatalogBuf = WStringToSQLWCHAR(fkCatalog); + std::vector fkSchemaBuf = WStringToSQLWCHAR(fkSchema); + std::vector fkTableBuf = WStringToSQLWCHAR(fkTable); + + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : pkSchemaBuf.data(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : pkTableBuf.data(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : fkCatalogBuf.data(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : fkSchemaBuf.data(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : fkTableBuf.data(), + fkTable.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : (SQLWCHAR*)pkCatalog.c_str(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : (SQLWCHAR*)pkSchema.c_str(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : (SQLWCHAR*)pkTable.c_str(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : (SQLWCHAR*)fkCatalog.c_str(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : (SQLWCHAR*)fkSchema.c_str(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : (SQLWCHAR*)fkTable.c_str(), + fkTable.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table) { + if (!SQLPrimaryKeys_ptr) { + ThrowStdException("SQLPrimaryKeys function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + if (!SQLStatistics_ptr) { + ThrowStdException("SQLStatistics function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + unique, + reserved); +#else + // Windows implementation + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS, + unique, + reserved); +#endif +} + +SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& tableObj, + const py::object& columnObj) { + if (!SQLColumns_ptr) { + ThrowStdException("SQLColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring tableStr = tableObj.is_none() ? L"" : tableObj.cast(); + std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalogStr); + std::vector schemaBuf = WStringToSQLWCHAR(schemaStr); + std::vector tableBuf = WStringToSQLWCHAR(tableStr); + std::vector columnBuf = WStringToSQLWCHAR(columnStr); + + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : catalogBuf.data(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : schemaBuf.data(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : tableBuf.data(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : columnBuf.data(), + columnStr.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(), + columnStr.empty() ? 0 : SQL_NTS); +#endif +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -1429,6 +1673,54 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta return SQL_SUCCESS; } +SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + if (!SQLSpecialColumns_ptr) { + ThrowStdException("SQLSpecialColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + scope, + nullable); +#else + // Windows implementation + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS, + scope, + nullable); +#endif +} + // Wrap SQLFetch to retrieve rows SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { LOG("Fetch next row"); @@ -2592,6 +2884,57 @@ PYBIND11_MODULE(ddbc_bindings, m) { }, "Set statement attributes"); m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, "Returns information about the data types that are supported by the data source", py::arg("StatementHandle"), py::arg("DataType")); + m.def("DDBCSQLProcedures", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const py::object& procedure) { + return SQLProcedures_wrap(StatementHandle, catalog, schema, procedure); + }); + + m.def("DDBCSQLForeignKeys", [](SqlHandlePtr StatementHandle, + const py::object& pkCatalog, + const py::object& pkSchema, + const py::object& pkTable, + const py::object& fkCatalog, + const py::object& fkSchema, + const py::object& fkTable) { + return SQLForeignKeys_wrap(StatementHandle, + pkCatalog, pkSchema, pkTable, + fkCatalog, fkSchema, fkTable); + }); + m.def("DDBCSQLPrimaryKeys", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table) { + return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, table); + }); + m.def("DDBCSQLSpecialColumns", [](SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + return SQLSpecialColumns_wrap(StatementHandle, + identifierType, catalog, schema, table, + scope, nullable); + }); + m.def("DDBCSQLStatistics", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved); + }); + m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const py::object& table, + const py::object& column) { + return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); + }); + // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 98d9f9f5b..d757ad954 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -106,6 +106,22 @@ typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLProceduresFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -150,6 +166,12 @@ extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; extern SQLGetTypeInfoFunc SQLGetTypeInfo_ptr; +extern SQLProceduresFunc SQLProcedures_ptr; +extern SQLForeignKeysFunc SQLForeignKeys_ptr; +extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr; +extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr; +extern SQLStatisticsFunc SQLStatistics_ptr; +extern SQLColumnsFunc SQLColumns_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 619cc2304..6dcb709b4 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -10,6 +10,7 @@ import pytest from datetime import datetime, date, time +import time as time_module import decimal from mssql_python import Connection import mssql_python @@ -1907,6 +1908,1836 @@ def test_gettypeinfo_cached_results(cursor): for row in second_result: assert row.data_type == ConstantsDDBC.SQL_VARCHAR.value, \ f"Expected SQL_VARCHAR type, got {row.data_type}" + +def test_procedures_setup(cursor, db_connection): + """Create a test schema and procedures for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_proc_schema') EXEC('CREATE SCHEMA pytest_proc_schema')") + + # Create test stored procedures + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc1 + AS + BEGIN + SELECT 1 AS result + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc2 + @param1 INT, + @param2 VARCHAR(50) OUTPUT + AS + BEGIN + SELECT @param2 = 'Output ' + CAST(@param1 AS VARCHAR(10)) + RETURN @param1 + END + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_procedures_all(cursor, db_connection): + """Test getting information about all procedures""" + # First set up our test procedures + test_procedures_setup(cursor, db_connection) + + try: + # Get all procedures + procs = cursor.procedures().fetchall() + + # Verify we got results + assert procs is not None, "procedures() should return results" + assert len(procs) > 0, "procedures() should return at least one procedure" + + # Verify structure of results + first_row = procs[0] + assert hasattr(first_row, 'procedure_cat'), "Result should have procedure_cat column" + assert hasattr(first_row, 'procedure_schem'), "Result should have procedure_schem column" + assert hasattr(first_row, 'procedure_name'), "Result should have procedure_name column" + assert hasattr(first_row, 'num_input_params'), "Result should have num_input_params column" + assert hasattr(first_row, 'num_output_params'), "Result should have num_output_params column" + assert hasattr(first_row, 'num_result_sets'), "Result should have num_result_sets column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'procedure_type'), "Result should have procedure_type column" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_specific(cursor, db_connection): + """Test getting information about a specific procedure""" + try: + # Get specific procedure + procs = cursor.procedures(procedure='test_proc1', schema='pytest_proc_schema').fetchall() + + # Verify we got the correct procedure + assert len(procs) == 1, "Should find exactly one procedure" + proc = procs[0] + assert proc.procedure_name == 'test_proc1;1', "Wrong procedure name returned" + assert proc.procedure_schem == 'pytest_proc_schema', "Wrong schema returned" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_with_schema(cursor, db_connection): + """Test getting procedures with schema filter""" + try: + # Get procedures for our test schema + procs = cursor.procedures(schema='pytest_proc_schema').fetchall() + + # Verify schema filter worked + assert len(procs) >= 2, "Should find at least two procedures in schema" + for proc in procs: + assert proc.procedure_schem == 'pytest_proc_schema', f"Expected schema pytest_proc_schema, got {proc.procedure_schem}" + + # Verify our specific procedures are in the results + proc_names = [p.procedure_name for p in procs] + assert 'test_proc1;1' in proc_names, "test_proc1;1 should be in results" + assert 'test_proc2;1' in proc_names, "test_proc2;1 should be in results" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_nonexistent(cursor): + """Test procedures() with non-existent procedure name""" + # Use a procedure name that's highly unlikely to exist + procs = cursor.procedures(procedure='nonexistent_procedure_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(procs, list), "Should return a list for non-existent procedure" + assert len(procs) == 0, "Should return empty list for non-existent procedure" + +def test_procedures_catalog_filter(cursor, db_connection): + """Test procedures() with catalog filter""" + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + try: + # Get procedures with current catalog + procs = cursor.procedures(catalog=current_db, schema='pytest_proc_schema').fetchall() + + # Verify catalog filter worked + assert len(procs) >= 2, "Should find procedures in current catalog" + for proc in procs: + assert proc.procedure_cat == current_db, f"Expected catalog {current_db}, got {proc.procedure_cat}" + + # Get procedures with non-existent catalog + fake_procs = cursor.procedures(catalog='nonexistent_db_xyz123').fetchall() + assert len(fake_procs) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_with_parameters(cursor, db_connection): + """Test that procedures() correctly reports parameter information""" + try: + # Create a simpler procedure with basic parameters + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_params_proc + @in1 INT, + @in2 VARCHAR(50) + AS + BEGIN + SELECT @in1 AS value1, @in2 AS value2 + END + """) + db_connection.commit() + + # Get procedure info + procs = cursor.procedures(procedure='test_params_proc', schema='pytest_proc_schema').fetchall() + + # Verify we found the procedure + assert len(procs) == 1, "Should find exactly one procedure" + proc = procs[0] + + # Just check if columns exist, don't check specific values + assert hasattr(proc, 'num_input_params'), "Result should have num_input_params column" + assert hasattr(proc, 'num_output_params'), "Result should have num_output_params column" + + # Test simple execution without output parameters + cursor.execute("EXEC pytest_proc_schema.test_params_proc 10, 'Test'") + + # Verify the procedure returned expected values + row = cursor.fetchone() + assert row is not None, "Procedure should return results" + assert row[0] == 10, "First parameter value incorrect" + assert row[1] == 'Test', "Second parameter value incorrect" + + finally: + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_params_proc") + db_connection.commit() + +def test_procedures_result_set_info(cursor, db_connection): + """Test that procedures() reports information about result sets""" + try: + # Create procedures with different result set patterns + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_no_results + AS + BEGIN + DECLARE @x INT = 1 + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_one_result + AS + BEGIN + SELECT 1 AS col1, 'test' AS col2 + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_multiple_results + AS + BEGIN + SELECT 1 AS result1 + SELECT 'test' AS result2 + SELECT GETDATE() AS result3 + END + """) + db_connection.commit() + + # Get procedure info for all test procedures + procs = cursor.procedures(schema='pytest_proc_schema', procedure='test_%').fetchall() + + # Verify we found at least some procedures + assert len(procs) > 0, "Should find at least some test procedures" + + # Get the procedure names we found + result_proc_names = [p.procedure_name for p in procs + if p.procedure_name.startswith('test_') and 'results' in p.procedure_name] + print(f"Found result procedures: {result_proc_names}") + + # The num_result_sets column exists but might not have correct values + for proc in procs: + assert hasattr(proc, 'num_result_sets'), "Result should have num_result_sets column" + + # Test execution of the procedures to verify they work + cursor.execute("EXEC pytest_proc_schema.test_no_results") + assert cursor.fetchall() == [], "test_no_results should return no results" + + cursor.execute("EXEC pytest_proc_schema.test_one_result") + rows = cursor.fetchall() + assert len(rows) == 1, "test_one_result should return one row" + assert len(rows[0]) == 2, "test_one_result row should have two columns" + + cursor.execute("EXEC pytest_proc_schema.test_multiple_results") + rows1 = cursor.fetchall() + assert len(rows1) == 1, "First result set should have one row" + assert cursor.nextset(), "Should have a second result set" + rows2 = cursor.fetchall() + assert len(rows2) == 1, "Second result set should have one row" + assert cursor.nextset(), "Should have a third result set" + rows3 = cursor.fetchall() + assert len(rows3) == 1, "Third result set should have one row" + + finally: + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_no_results") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_one_result") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_multiple_results") + db_connection.commit() + +def test_procedures_cleanup(cursor, db_connection): + """Clean up all test procedures and schema after testing""" + try: + # Drop all test procedures + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_proc1") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_proc2") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_params_proc") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_no_results") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_one_result") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_multiple_results") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_proc_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_foreignkeys_setup(cursor, db_connection): + """Create tables with foreign key relationships for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + + # Create parent table + cursor.execute(""" + CREATE TABLE pytest_fk_schema.customers ( + customer_id INT PRIMARY KEY, + customer_name VARCHAR(100) NOT NULL + ) + """) + + # Create child table with foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.orders ( + order_id INT PRIMARY KEY, + order_date DATETIME NOT NULL, + customer_id INT NOT NULL, + total_amount DECIMAL(10, 2) NOT NULL, + CONSTRAINT FK_Orders_Customers FOREIGN KEY (customer_id) + REFERENCES pytest_fk_schema.customers (customer_id) + ) + """) + + # Insert test data + cursor.execute(""" + INSERT INTO pytest_fk_schema.customers (customer_id, customer_name) + VALUES (1, 'Test Customer 1'), (2, 'Test Customer 2') + """) + + cursor.execute(""" + INSERT INTO pytest_fk_schema.orders (order_id, order_date, customer_id, total_amount) + VALUES (101, GETDATE(), 1, 150.00), (102, GETDATE(), 2, 250.50) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_foreignkeys_all(cursor, db_connection): + """Test getting all foreign keys""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get all foreign keys + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert fks is not None, "foreignKeys() should return results" + assert len(fks) > 0, "foreignKeys() should return at least one foreign key" + + # Verify our test FK is in the results + # Search case-insensitively since the database might return different case + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_table(cursor, db_connection): + """Test getting foreign keys for a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key for orders table" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_foreign_table(cursor, db_connection): + """Test getting foreign keys that reference a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys that reference the customers table + fks = cursor.foreignKeys(foreignTable='customers', foreignSchema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key referencing customers table" + + # Verify our test FK is in the results + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_both_tables(cursor, db_connection): + """Test getting foreign keys with both table and foreignTable specified""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys between the two tables + fks = cursor.foreignKeys( + table='orders', schema='pytest_fk_schema', + foreignTable='customers', foreignSchema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key between specified tables" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_nonexistent(cursor): + """Test foreignKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + fks = cursor.foreignKeys(table='nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(fks, list), "Should return a list for non-existent table" + assert len(fks) == 0, "Should return empty list for non-existent table" + +def test_foreignkeys_catalog_schema(cursor, db_connection): + """Test foreignKeys() with catalog and schema filters""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + row = cursor.fetchone() + current_db = row.current_db + + # Get foreign keys with current catalog and pytest schema + fks = cursor.foreignKeys( + table='orders', + catalog=current_db, + schema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find foreign keys with correct catalog/schema" + + # Verify catalog/schema in results + for fk in fks: + assert fk.fktable_cat == current_db, "Wrong foreign key table catalog" + assert fk.fktable_schem == 'pytest_fk_schema', "Wrong foreign key table schema" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_result_structure(cursor, db_connection): + """Test the structure of foreignKeys result rows""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key" + + # Check for all required columns in the result + first_row = fks[0] + required_columns = [ + 'pktable_cat', 'pktable_schem', 'pktable_name', 'pkcolumn_name', + 'fktable_cat', 'fktable_schem', 'fktable_name', 'fkcolumn_name', + 'key_seq', 'update_rule', 'delete_rule', 'fk_name', 'pk_name', + 'deferrability' + ] + + for column in required_columns: + assert hasattr(first_row, column), f"Result missing required column: {column}" + + # Verify specific values + assert first_row.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert first_row.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert first_row.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert first_row.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + assert first_row.key_seq == 1, "Wrong key sequence number" + assert first_row.fk_name is not None, "Foreign key name should not be None" + assert first_row.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_multiple_column_fk(cursor, db_connection): + """Test foreignKeys() with a multi-column foreign key""" + try: + # First create the schema if needed + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + + # Create parent table with composite primary key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.product_variants ( + product_id INT NOT NULL, + variant_id INT NOT NULL, + variant_name VARCHAR(100) NOT NULL, + PRIMARY KEY (product_id, variant_id) + ) + """) + + # Create child table with composite foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.order_details ( + order_id INT NOT NULL, + product_id INT NOT NULL, + variant_id INT NOT NULL, + quantity INT NOT NULL, + PRIMARY KEY (order_id, product_id, variant_id), + CONSTRAINT FK_OrderDetails_ProductVariants FOREIGN KEY (product_id, variant_id) + REFERENCES pytest_fk_schema.product_variants (product_id, variant_id) + ) + """) + + db_connection.commit() + + # Get foreign keys for the order_details table + fks = cursor.foreignKeys(table='order_details', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 2, "Should find two rows for the composite foreign key (one per column)" + + # Group by key_seq to verify both columns + fk_columns = {} + for fk in fks: + fk_columns[fk.key_seq] = { + 'pkcolumn': fk.pkcolumn_name.lower(), + 'fkcolumn': fk.fkcolumn_name.lower() + } + + # Verify both columns are present + assert 1 in fk_columns, "First column of composite key missing" + assert 2 in fk_columns, "Second column of composite key missing" + + # Verify column mappings + assert fk_columns[1]['pkcolumn'] == 'product_id', "Wrong primary key column 1" + assert fk_columns[1]['fkcolumn'] == 'product_id', "Wrong foreign key column 1" + assert fk_columns[2]['pkcolumn'] == 'variant_id', "Wrong primary key column 2" + assert fk_columns[2]['fkcolumn'] == 'variant_id', "Wrong foreign key column 2" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + +def test_cleanup_schema(cursor, db_connection): + """Clean up the test schema after all tests""" + try: + # Make sure no tables remain + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + + # Drop the schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_fk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Schema cleanup failed: {e}") + +def test_primarykeys_setup(cursor, db_connection): + """Create tables with primary keys for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_pk_schema') EXEC('CREATE SCHEMA pytest_pk_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Create table with simple primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.single_pk_test ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + description VARCHAR(200) NULL + ) + """) + + # Create table with composite primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.composite_pk_test ( + dept_id INT NOT NULL, + emp_id INT NOT NULL, + hire_date DATE NOT NULL, + CONSTRAINT PK_composite_test PRIMARY KEY (dept_id, emp_id) + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_primarykeys_simple(cursor, db_connection): + """Test primaryKeys returns information about a simple primary key""" + try: + # First set up our test tables + test_primarykeys_setup(cursor, db_connection) + + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify we got results + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify primary key details + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong primary key column name" + assert pk.key_seq == 1, "Wrong key sequence number" + assert pk.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_composite(cursor, db_connection): + """Test primaryKeys with a composite primary key""" + try: + # Get primary key information + pks = cursor.primaryKeys('composite_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify we got results for both columns + assert len(pks) == 2, "Should find two primary key columns" + + # Sort by key_seq to ensure consistent order + pks = sorted(pks, key=lambda row: row.key_seq) + + # Verify first column + assert pks[0].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[0].column_name.lower() == 'dept_id', "Wrong first primary key column name" + assert pks[0].key_seq == 1, "Wrong key sequence number for first column" + + # Verify second column + assert pks[1].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[1].column_name.lower() == 'emp_id', "Wrong second primary key column name" + assert pks[1].key_seq == 2, "Wrong key sequence number for second column" + + # Both should have the same PK name + assert pks[0].pk_name == pks[1].pk_name, "Both columns should have the same primary key name" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_column_info(cursor, db_connection): + """Test that primaryKeys returns correct column information""" + try: + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify column information + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify expected columns are present + assert hasattr(pk, 'table_cat'), "Result should have table_cat column" + assert hasattr(pk, 'table_schem'), "Result should have table_schem column" + assert hasattr(pk, 'table_name'), "Result should have table_name column" + assert hasattr(pk, 'column_name'), "Result should have column_name column" + assert hasattr(pk, 'key_seq'), "Result should have key_seq column" + assert hasattr(pk, 'pk_name'), "Result should have pk_name column" + + # Verify values are correct + assert pk.table_schem.lower() == 'pytest_pk_schema', "Wrong schema name" + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong column name" + assert isinstance(pk.key_seq, int), "key_seq should be an integer" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_nonexistent(cursor): + """Test primaryKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + pks = cursor.primaryKeys('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(pks, list), "Should return a list for non-existent table" + assert len(pks) == 0, "Should return empty list for non-existent table" + +def test_primarykeys_catalog_filter(cursor, db_connection): + """Test primaryKeys() with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get primary keys with current catalog + pks = cursor.primaryKeys('single_pk_test', catalog=current_db, schema='pytest_pk_schema').fetchall() + + # Verify catalog filter worked + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + assert pk.table_cat == current_db, f"Expected catalog {current_db}, got {pk.table_cat}" + + # Get primary keys with non-existent catalog + fake_pks = cursor.primaryKeys('single_pk_test', catalog='nonexistent_db_xyz123').fetchall() + assert len(fake_pks) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_pk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_specialcolumns_setup(cursor, db_connection): + """Create test tables for testing rowIdColumns and rowVerColumns""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_special_schema') EXEC('CREATE SCHEMA pytest_special_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") + + # Create table with primary key (for rowIdColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.rowid_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + unique_col NVARCHAR(100) UNIQUE, + non_unique_col NVARCHAR(100) + ) + """) + + # Create table with rowversion column (for rowVerColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.timestamp_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_updated ROWVERSION + ) + """) + + # Create table with multiple unique identifiers + cursor.execute(""" + CREATE TABLE pytest_special_schema.multiple_unique_test ( + id INT NOT NULL, + code VARCHAR(10) NOT NULL, + email VARCHAR(100) UNIQUE, + order_number VARCHAR(20) UNIQUE, + CONSTRAINT PK_multiple_unique_test PRIMARY KEY (id, code) + ) + """) + + # Create table with identity column + cursor.execute(""" + CREATE TABLE pytest_special_schema.identity_test ( + id INT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_modified DATETIME DEFAULT GETDATE() + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_rowid_columns_basic(cursor, db_connection): + """Test basic functionality of rowIdColumns""" + try: + # Get row identifier columns for simple table + rowid_cols = cursor.rowIdColumns( + table='rowid_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (first column of PK)" + + # Verify column name in the results + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Primary key column should be included in ROWID results" + + # Verify result structure + assert hasattr(col, 'scope'), "Result should have scope column" + assert hasattr(col, 'column_name'), "Result should have column_name column" + assert hasattr(col, 'data_type'), "Result should have data_type column" + assert hasattr(col, 'type_name'), "Result should have type_name column" + assert hasattr(col, 'column_size'), "Result should have column_size column" + assert hasattr(col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert col.scope in [0, 1, 2, None], f"Invalid scope value: {col.scope}" + + # The pseudo_column should be one of the valid values + assert col.pseudo_column in [0, 1, 2, None], f"Invalid pseudo_column value: {col.pseudo_column}" + + except Exception as e: + pytest.fail(f"rowIdColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_identity(cursor, db_connection): + """Test rowIdColumns with identity column""" + try: + # Get row identifier columns for table with identity column + rowid_cols = cursor.rowIdColumns( + table='identity_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns the identity column if it's the primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (identity column as PK)" + + # Verify it's the identity column + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Identity column should be included as it's the PK" + + except Exception as e: + pytest.fail(f"rowIdColumns identity test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_composite(cursor, db_connection): + """Test rowIdColumns with composite primary key""" + try: + # Get row identifier columns for table with composite primary key + rowid_cols = cursor.rowIdColumns( + table='multiple_unique_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of composite primary key + assert len(rowid_cols) >= 1, "Should find at least one ROWID column (first column of PK)" + + # Verify column names in the results - should be the first PK column + col_names = [col.column_name.lower() for col in rowid_cols] + assert 'id' in col_names, "First part of composite PK should be included" + + # LIMITATION: Other parts of the PK or unique constraints may not be included + if len(rowid_cols) > 1: + # If additional columns are returned, they should be valid + for col in rowid_cols: + assert col.column_name.lower() in ['id', 'code'], "Only PK columns should be returned" + + except Exception as e: + pytest.fail(f"rowIdColumns composite test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_nonexistent(cursor): + """Test rowIdColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowid_cols = cursor.rowIdColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowid_cols, list), "Should return a list for non-existent table" + assert len(rowid_cols) == 0, "Should return empty list for non-existent table" + +def test_rowid_columns_nullable(cursor, db_connection): + """Test rowIdColumns with nullable parameter""" + try: + # First create a table with nullable unique column and non-nullable PK + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_test ( + id INT PRIMARY KEY, -- PK can't be nullable in SQL Server + data NVARCHAR(100) NULL + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowid_cols_with_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify PK column is included + assert len(rowid_cols_with_nullable) == 1, "Should return exactly one column (PK)" + assert rowid_cols_with_nullable[0].column_name.lower() == 'id', "PK column should be returned" + + # Test with nullable=False + rowid_cols_no_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema', + nullable=False + ).fetchall() + + # The behavior of SQLSpecialColumns with SQL_NO_NULLS is to only return + # non-nullable columns that uniquely identify a row, but SQL Server returns + # an empty set in this case - this is expected behavior + assert len(rowid_cols_no_nullable) == 0, "Should return empty list when nullable=False (ODBC API behavior)" + + except Exception as e: + pytest.fail(f"rowIdColumns nullable test failed: {e}") + finally: + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_test") + db_connection.commit() + +def test_rowver_columns_basic(cursor, db_connection): + """Test basic functionality of rowVerColumns""" + try: + # Get version columns from timestamp test table + rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify we got results + assert len(rowver_cols) == 1, "Should find exactly one ROWVER column" + + # Verify the column is the rowversion column + rowver_col = rowver_cols[0] + assert rowver_col.column_name.lower() == 'last_updated', "ROWVER column should be 'last_updated'" + assert rowver_col.type_name.lower() in ['rowversion', 'timestamp'], "ROWVER column should have rowversion or timestamp type" + + # Verify result structure - allowing for NULL values + assert hasattr(rowver_col, 'scope'), "Result should have scope column" + assert hasattr(rowver_col, 'column_name'), "Result should have column_name column" + assert hasattr(rowver_col, 'data_type'), "Result should have data_type column" + assert hasattr(rowver_col, 'type_name'), "Result should have type_name column" + assert hasattr(rowver_col, 'column_size'), "Result should have column_size column" + assert hasattr(rowver_col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(rowver_col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(rowver_col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert rowver_col.scope in [0, 1, 2, None], f"Invalid scope value: {rowver_col.scope}" + + except Exception as e: + pytest.fail(f"rowVerColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowver_columns_nonexistent(cursor): + """Test rowVerColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowver_cols = cursor.rowVerColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowver_cols, list), "Should return a list for non-existent table" + assert len(rowver_cols) == 0, "Should return empty list for non-existent table" + +def test_rowver_columns_nullable(cursor, db_connection): + """Test rowVerColumns with nullable parameter (not expected to have effect)""" + try: + # First create a table with rowversion column + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_rowver_test ( + id INT PRIMARY KEY, + ts ROWVERSION + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowver_cols_with_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify rowversion column is included (rowversion can't be nullable) + assert len(rowver_cols_with_nullable) == 1, "Should find exactly one ROWVER column" + assert rowver_cols_with_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included" + + # Test with nullable=False + rowver_cols_no_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + schema='pytest_special_schema', + nullable=False + ).fetchall() + + # Verify rowversion column is still included + assert len(rowver_cols_no_nullable) == 1, "Should find exactly one ROWVER column" + assert rowver_cols_no_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included even with nullable=False" + + except Exception as e: + pytest.fail(f"rowVerColumns nullable test failed: {e}") + finally: + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_rowver_test") + db_connection.commit() + +def test_specialcolumns_catalog_filter(cursor, db_connection): + """Test special columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Test rowIdColumns with current catalog + rowid_cols = cursor.rowIdColumns( + table='rowid_test', + catalog=current_db, + schema='pytest_special_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(rowid_cols) > 0, "Should find ROWID columns with correct catalog" + + # Test rowIdColumns with non-existent catalog + fake_rowid_cols = cursor.rowIdColumns( + table='rowid_test', + catalog='nonexistent_db_xyz123', + schema='pytest_special_schema' + ).fetchall() + assert len(fake_rowid_cols) == 0, "Should return empty list for non-existent catalog" + + # Test rowVerColumns with current catalog + rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + catalog=current_db, + schema='pytest_special_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(rowver_cols) > 0, "Should find ROWVER columns with correct catalog" + + # Test rowVerColumns with non-existent catalog + fake_rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + catalog='nonexistent_db_xyz123', + schema='pytest_special_schema' + ).fetchall() + assert len(fake_rowver_cols) == 0, "Should return empty list for non-existent catalog" + + except Exception as e: + pytest.fail(f"Special columns catalog filter test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_specialcolumns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_timestamp_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_special_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_statistics_setup(cursor, db_connection): + """Create test tables and indexes for statistics testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_stats_schema') EXEC('CREATE SCHEMA pytest_stats_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test") + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") + + # Create test table with various indexes + cursor.execute(""" + CREATE TABLE pytest_stats_schema.stats_test ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(100) UNIQUE, + department VARCHAR(50) NOT NULL, + salary DECIMAL(10, 2) NULL, + hire_date DATE NOT NULL + ) + """) + + # Create a non-unique index + cursor.execute(""" + CREATE INDEX IX_stats_test_dept_date ON pytest_stats_schema.stats_test (department, hire_date) + """) + + # Create a unique index on multiple columns + cursor.execute(""" + CREATE UNIQUE INDEX UX_stats_test_name_dept ON pytest_stats_schema.stats_test (name, department) + """) + + # Create an empty table for testing + cursor.execute(""" + CREATE TABLE pytest_stats_schema.empty_stats_test ( + id INT PRIMARY KEY, + data VARCHAR(100) NULL + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_statistics_basic(cursor, db_connection): + """Test basic functionality of statistics method""" + try: + # First set up our test tables + test_statistics_setup(cursor, db_connection) + + # Get statistics for the test table (all indexes) + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we got results - should include PK, unique index on email, and non-unique index + assert stats is not None, "statistics() should return results" + assert len(stats) > 0, "statistics() should return at least one row" + + # Count different types of indexes + table_stats = [s for s in stats if s.type == 0] # TABLE_STAT + indexes = [s for s in stats if s.type != 0] # Actual indexes + + # We should have at least one table statistics row and multiple index rows + assert len(table_stats) <= 1, "Should have at most one TABLE_STAT row" + assert len(indexes) >= 3, "Should have at least 3 index entries (PK, unique email, non-unique dept+date)" + + # Verify column names in results + first_row = stats[0] + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'non_unique'), "Result should have non_unique column" + assert hasattr(first_row, 'index_name'), "Result should have index_name column" + assert hasattr(first_row, 'type'), "Result should have type column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + + # Check that we can find the primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results" + + # Check that we can find the unique index on email + email_index_found = False + for stat in stats: + if (hasattr(stat, 'column_name') and + stat.column_name and + stat.column_name.lower() == 'email' and + hasattr(stat, 'non_unique') and + stat.non_unique == 0): # 0 = unique + email_index_found = True + break + + assert email_index_found, "Unique index on email should be included in statistics results" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_unique_only(cursor, db_connection): + """Test statistics with unique=True to get only unique indexes""" + try: + # Get statistics for only unique indexes + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + unique=True + ).fetchall() + + # Verify we got results + assert stats is not None, "statistics() with unique=True should return results" + assert len(stats) > 0, "statistics() with unique=True should return at least one row" + + # All index entries should be for unique indexes (non_unique = 0) + for stat in stats: + if hasattr(stat, 'type') and stat.type != 0: # Skip TABLE_STAT entries + assert hasattr(stat, 'non_unique'), "Index entry should have non_unique column" + assert stat.non_unique == 0, "With unique=True, all indexes should be unique" + + # Count different types of indexes + indexes = [s for s in stats if hasattr(s, 'type') and s.type != 0] + + # We should have multiple unique indexes (PK, unique email, unique name+dept) + assert len(indexes) >= 3, "Should have at least 3 unique index entries" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_empty_table(cursor, db_connection): + """Test statistics on a table with no data (just schema)""" + try: + # Get statistics for the empty table + stats = cursor.statistics( + table='empty_stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Should still return metadata about the primary key + assert stats is not None, "statistics() should return results even for empty table" + assert len(stats) > 0, "statistics() should return at least one row for empty table" + + # Check for primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results for empty table" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_nonexistent(cursor): + """Test statistics with non-existent table name""" + # Use a table name that's highly unlikely to exist + stats = cursor.statistics('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(stats, list), "Should return a list for non-existent table" + assert len(stats) == 0, "Should return empty list for non-existent table" + +def test_statistics_result_structure(cursor, db_connection): + """Test the complete structure of statistics result rows""" + try: + # Get statistics for the test table + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we have results + assert len(stats) > 0, "Should have statistics results" + + # Find a row that's an actual index (not TABLE_STAT) + index_row = None + for stat in stats: + if hasattr(stat, 'type') and stat.type != 0: + index_row = stat + break + + assert index_row is not None, "Should have at least one index row" + + # Check for all required columns + required_columns = [ + 'table_cat', 'table_schem', 'table_name', 'non_unique', + 'index_qualifier', 'index_name', 'type', 'ordinal_position', + 'column_name', 'asc_or_desc', 'cardinality', 'pages', + 'filter_condition' + ] + + for column in required_columns: + assert hasattr(index_row, column), f"Result missing required column: {column}" + + # Check types of key columns + assert isinstance(index_row.table_name, str), "table_name should be a string" + assert isinstance(index_row.type, int), "type should be an integer" + + # Don't check the actual values of cardinality and pages as they may be NULL + # or driver-dependent, especially for empty tables + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_catalog_filter(cursor, db_connection): + """Test statistics with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get statistics with current catalog + stats = cursor.statistics( + table='stats_test', + catalog=current_db, + schema='pytest_stats_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(stats) > 0, "Should find statistics with correct catalog" + + # Verify catalog in results + for stat in stats: + if hasattr(stat, 'table_cat'): + assert stat.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Get statistics with non-existent catalog + fake_stats = cursor.statistics( + table='stats_test', + catalog='nonexistent_db_xyz123', + schema='pytest_stats_schema' + ).fetchall() + assert len(fake_stats) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_with_quick_parameter(cursor, db_connection): + """Test statistics with quick parameter variations""" + try: + # Test with quick=True (default) + quick_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=True + ).fetchall() + + # Test with quick=False + thorough_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=False + ).fetchall() + + # Both should return results, but we can't guarantee behavior differences + # since it depends on the ODBC driver and database system + assert len(quick_stats) > 0, "quick=True should return results" + assert len(thorough_stats) > 0, "quick=False should return results" + + # Just verify that changing the parameter didn't cause errors + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test") + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_stats_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_columns_setup(cursor, db_connection): + """Create test tables for columns method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_cols_schema') EXEC('CREATE SCHEMA pytest_cols_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Create test table with various column types + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + description NVARCHAR(MAX) NULL, + price DECIMAL(10, 2) NULL, + created_date DATETIME DEFAULT GETDATE(), + is_active BIT NOT NULL DEFAULT 1, + binary_data VARBINARY(MAX) NULL, + notes TEXT NULL, + [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10))) + ) + """) + + # Create table with special column names and edge cases - fix the problematic column name + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_special_test ( + [ID] INT PRIMARY KEY, + [User Name] NVARCHAR(100) NULL, + [Spaces Multiple] VARCHAR(50) NULL, + [123_numeric_start] INT NULL, + [MAX] VARCHAR(20) NULL, -- SQL keyword as column name + [SELECT] INT NULL, -- SQL keyword as column name + [Column.With.Dots] VARCHAR(20) NULL, + [Column/With/Slashes] VARCHAR(20) NULL, + [Column_With_Underscores] VARCHAR(20) NULL -- Changed from problematic nested brackets + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_columns_all(cursor, db_connection): + """Test columns returns information about all columns in all tables""" + try: + # First set up our test tables + test_columns_setup(cursor, db_connection) + + # Get all columns (no filters) + cols_cursor = cursor.columns() + cols = cols_cursor.fetchall() + + # Verify we got results + assert cols is not None, "columns() should return results" + assert len(cols) > 0, "columns() should return at least one column" + + # Verify our test tables' columns are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for col in cols: + if (hasattr(col, 'table_name') and + col.table_name and + col.table_name.lower() == 'columns_test' and + hasattr(col, 'table_schem') and + col.table_schem and + col.table_schem.lower() == 'pytest_cols_schema'): + found_test_table = True + break + + assert found_test_table, "Test table columns should be included in results" + + # Verify structure of results + first_row = cols[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(first_row, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(first_row, 'num_prec_radix'), "Result should have num_prec_radix column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'column_def'), "Result should have column_def column" + assert hasattr(first_row, 'sql_data_type'), "Result should have sql_data_type column" + assert hasattr(first_row, 'sql_datetime_sub'), "Result should have sql_datetime_sub column" + assert hasattr(first_row, 'char_octet_length'), "Result should have char_octet_length column" + assert hasattr(first_row, 'ordinal_position'), "Result should have ordinal_position column" + assert hasattr(first_row, 'is_nullable'), "Result should have is_nullable column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_table(cursor, db_connection): + """Test columns returns information about a specific table""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_test" + + # Verify all column names are present (case insensitive) + col_names = [col.column_name.lower() for col in cols] + expected_names = ['id', 'name', 'description', 'price', 'created_date', + 'is_active', 'binary_data', 'notes', 'computed_col'] + + for name in expected_names: + assert name in col_names, f"Column {name} should be in results" + + # Verify details of a specific column (id) + id_col = next(col for col in cols if col.column_name.lower() == 'id') + assert id_col.nullable == 0, "id column should be non-nullable" + assert id_col.ordinal_position == 1, "id should be the first column" + assert id_col.is_nullable == "NO", "is_nullable should be NO for id column" + + # Check data types (but don't assume specific ODBC type codes since they vary by driver) + # Instead check that the type_name is correct + id_type = id_col.type_name.lower() + assert 'int' in id_type, f"id column should be INTEGER type, got {id_type}" + + # Check a nullable column + desc_col = next(col for col in cols if col.column_name.lower() == 'description') + assert desc_col.nullable == 1, "description column should be nullable" + assert desc_col.is_nullable == "YES", "is_nullable should be YES for description column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_special_chars(cursor, db_connection): + """Test columns with special characters and edge cases""" + try: + # Get columns for the special table + cols = cursor.columns( + table='columns_special_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_special_test" + + # Check that special column names are handled correctly + col_names = [col.column_name for col in cols] + + # Create case-insensitive lookup + col_names_lower = [name.lower() if name else None for name in col_names] + + # Check for columns with special characters - note that column names might be + # returned with or without brackets/quotes depending on the driver + assert any('user name' in name.lower() for name in col_names), "Column with spaces should be in results" + assert any('id' == name.lower() for name in col_names), "ID column should be in results" + assert any('123_numeric_start' in name.lower() for name in col_names), "Column starting with numbers should be in results" + assert any('max' == name.lower() for name in col_names), "MAX column should be in results" + assert any('select' == name.lower() for name in col_names), "SELECT column should be in results" + assert any('column.with.dots' in name.lower() for name in col_names), "Column with dots should be in results" + assert any('column/with/slashes' in name.lower() for name in col_names), "Column with slashes should be in results" + assert any('column_with_underscores' in name.lower() for name in col_names), "Column with underscores should be in results" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_column(cursor, db_connection): + """Test columns with specific column filter""" + try: + # Get specific column + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='name' + ).fetchall() + + # Verify we got just one result + assert len(cols) == 1, "Should find exactly 1 column named 'name'" + + # Verify column details + col = cols[0] + assert col.column_name.lower() == 'name', "Column name should be 'name'" + assert col.table_name.lower() == 'columns_test', "Table name should be 'columns_test'" + assert col.table_schem.lower() == 'pytest_cols_schema', "Schema should be 'pytest_cols_schema'" + assert col.nullable == 0, "name column should be non-nullable" + + # Get column using pattern (% wildcard) + pattern_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%date%' + ).fetchall() + + # Should find created_date column + assert len(pattern_cols) == 1, "Should find 1 column matching '%date%'" + + assert pattern_cols[0].column_name.lower() == 'created_date', "Should find created_date column" + + # Get multiple columns with pattern + multi_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%d%' # Should match id, description, created_date + ).fetchall() + + # At least 3 columns should match this pattern + assert len(multi_cols) >= 3, "Should find at least 3 columns matching '%d%'" + match_names = [col.column_name.lower() for col in multi_cols] + assert 'id' in match_names, "id should match '%d%'" + assert 'description' in match_names, "description should match '%d%'" + assert 'created_date' in match_names, "created_date should match '%d%'" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_with_underscore_pattern(cursor): + """Test columns with underscore wildcard pattern""" + try: + # Get columns with underscore pattern (one character wildcard) + # Looking for 'id' (exactly 2 chars) + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='__' + ).fetchall() + + # Should find 'id' column + id_found = False + for col in cols: + if col.column_name.lower() == 'id' and col.table_name.lower() == 'columns_test': + id_found = True + break + + assert id_found, "Should find 'id' column with pattern '__'" + + # Try a more complex pattern with both % and _ + # For example: '%_d%' matches any column with 'd' as the second or later character + pattern_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%_d%' + ).fetchall() + + # Should match 'id' (if considering case-insensitive) and 'created_date' + match_names = [col.column_name.lower() for col in pattern_cols + if col.table_name.lower() == 'columns_test'] + + # At least 'created_date' should match this pattern + assert 'created_date' in match_names, "created_date should match '%_d%'" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_nonexistent(cursor): + """Test columns with non-existent table or column""" + # Test with non-existent table + table_cols = cursor.columns(table='nonexistent_table_xyz123') + assert len(table_cols) == 0, "Should return empty list for non-existent table" + + # Test with non-existent column in existing table + col_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='nonexistent_column_xyz123' + ).fetchall() + assert len(col_cols) == 0, "Should return empty list for non-existent column" + + # Test with non-existent schema + schema_cols = cursor.columns( + table='columns_test', + schema='nonexistent_schema_xyz123' + ) + assert len(schema_cols) == 0, "Should return empty list for non-existent schema" + +def test_columns_data_types(cursor): + """Test columns returns correct data type information""" + try: + # Get all columns from test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Create a dictionary mapping column names to their details + col_dict = {col.column_name.lower(): col for col in cols} + + # Check data types by name (case insensitive checks) + # Note: We're checking type_name as a string to avoid SQL type code inconsistencies + # between drivers + + # INT column + assert 'int' in col_dict['id'].type_name.lower(), "id should be INT type" + + # NVARCHAR column + assert any(name in col_dict['name'].type_name.lower() + for name in ['nvarchar', 'varchar', 'char', 'wchar']), "name should be NVARCHAR type" + + # DECIMAL column + assert any(name in col_dict['price'].type_name.lower() + for name in ['decimal', 'numeric', 'money']), "price should be DECIMAL type" + + # BIT column + assert any(name in col_dict['is_active'].type_name.lower() + for name in ['bit', 'boolean']), "is_active should be BIT type" + + # TEXT column + assert any(name in col_dict['notes'].type_name.lower() + for name in ['text', 'char', 'varchar']), "notes should be TEXT type" + + # Check nullable flag + assert col_dict['id'].nullable == 0, "id should be non-nullable" + assert col_dict['description'].nullable == 1, "description should be nullable" + + # Check column size + assert col_dict['name'].column_size == 100, "name should have size 100" + + # Check decimal digits for numeric type + assert col_dict['price'].decimal_digits == 2, "price should have 2 decimal digits" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_nonexistent(cursor): + """Test columns with non-existent table or column""" + # Test with non-existent table + table_cols = cursor.columns(table='nonexistent_table_xyz123').fetchall() + assert len(table_cols) == 0, "Should return empty list for non-existent table" + + # Test with non-existent column in existing table + col_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='nonexistent_column_xyz123' + ).fetchall() + assert len(col_cols) == 0, "Should return empty list for non-existent column" + + # Test with non-existent schema + schema_cols = cursor.columns( + table='columns_test', + schema='nonexistent_schema_xyz123' + ).fetchall() + assert len(schema_cols) == 0, "Should return empty list for non-existent schema" + +def test_columns_catalog_filter(cursor): + """Test columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get columns with current catalog + cols = cursor.columns( + table='columns_test', + catalog=current_db, + schema='pytest_cols_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(cols) > 0, "Should find columns with correct catalog" + + # Check catalog in results + for col in cols: + # Some drivers might return None for catalog + if col.table_cat is not None: + assert col.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_cols = cursor.columns( + table='columns_test', + catalog='nonexistent_db_xyz123', + schema='pytest_cols_schema' + ).fetchall() + assert len(fake_cols) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_schema_pattern(cursor): + """Test columns with schema name pattern""" + try: + # Get columns with schema pattern + cols = cursor.columns( + table='columns_test', + schema='pytest_%' + ).fetchall() + + # Should find our test table columns + test_cols = [col for col in cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using schema pattern" + + # Try a more specific pattern + specific_cols = cursor.columns( + table='columns_test', + schema='pytest_cols%' + ).fetchall() + + # Should still find our test table columns + test_cols = [col for col in specific_cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using specific schema pattern" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_table_pattern(cursor): + """Test columns with table name pattern""" + try: + # Get columns with table pattern + cols = cursor.columns( + table='columns_%', + schema='pytest_cols_schema' + ).fetchall() + + # Should find columns from both test tables + tables_found = set() + for col in cols: + if col.table_name: + tables_found.add(col.table_name.lower()) + + assert 'columns_test' in tables_found, "Should find columns_test with pattern columns_%" + assert 'columns_special_test' in tables_found, "Should find columns_special_test with pattern columns_%" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_ordinal_position(cursor): + """Test ordinal_position is correct in columns results""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Sort by ordinal position + sorted_cols = sorted(cols, key=lambda col: col.ordinal_position) + + # Verify positions are consecutive starting from 1 + for i, col in enumerate(sorted_cols, 1): + assert col.ordinal_position == i, f"Column {col.column_name} should have ordinal_position {i}" + + # First column should be id (primary key) + assert sorted_cols[0].column_name.lower() == 'id', "First column should be id" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_cols_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") def test_close(db_connection): """Test closing the cursor"""