diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 8e118cd5b..071136462 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -50,6 +50,7 @@ # Export specific constants for setencoding() SQL_CHAR = ConstantsDDBC.SQL_CHAR.value SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value +SQL_WMETADATA = -99 # GLOBALS # Read-Only diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 9b19a603a..187f3e33e 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -21,6 +21,9 @@ from mssql_python.auth import process_connection_string from mssql_python.constants import ConstantsDDBC +# Add SQL_WMETADATA constant for metadata decoding configuration +SQL_WMETADATA = -99 # Special flag for column name decoding + # UTF-16 encoding variants that should use SQL_WCHAR by default UTF16_ENCODINGS = frozenset([ 'utf-16', @@ -28,7 +31,6 @@ 'utf-16be' ]) - def _validate_encoding(encoding: str) -> bool: """ Cached encoding validation using codecs.lookup(). @@ -67,6 +69,8 @@ class Connection: rollback() -> None: close() -> None: setencoding(encoding=None, ctype=None) -> None: + setdecoding(sqltype, encoding=None, ctype=None) -> None: + getdecoding(sqltype) -> dict: """ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None: @@ -101,6 +105,22 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef 'ctype': ConstantsDDBC.SQL_WCHAR.value } + # Initialize decoding settings with Python 3 defaults + self._decoding_settings = { + ConstantsDDBC.SQL_CHAR.value: { + 'encoding': 'utf-8', + 'ctype': ConstantsDDBC.SQL_CHAR.value + }, + ConstantsDDBC.SQL_WCHAR.value: { + 'encoding': 'utf-16le', + 'ctype': ConstantsDDBC.SQL_WCHAR.value + }, + SQL_WMETADATA: { + 'encoding': 'utf-16le', + 'ctype': ConstantsDDBC.SQL_WCHAR.value + } + } + # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. # If authentication is specified, it will be processed to handle @@ -297,6 +317,147 @@ def getencoding(self): return self._encoding_settings.copy() + def setdecoding(self, sqltype, encoding=None, ctype=None): + """ + Sets the text decoding used when reading SQL_CHAR and SQL_WCHAR from the database. + + This method configures how text data is decoded when reading from the database. + In Python 3, all text is Unicode (str), so this primarily affects the encoding + used to decode bytes from the database. + + Args: + sqltype (int): The SQL type being configured: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. + SQL_WMETADATA is a special flag for configuring column name decoding. + encoding (str, optional): The Python encoding to use when decoding the data. + If None, uses default encoding based on sqltype. + ctype (int, optional): The C data type to request from SQLGetData: + SQL_CHAR or SQL_WCHAR. If None, uses default based on encoding. + + Returns: + None + + Raises: + ProgrammingError: If the sqltype, encoding, or ctype is invalid. + InterfaceError: If the connection is closed. + + Example: + # Configure SQL_CHAR to use UTF-8 decoding + cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Configure column metadata decoding + cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') + + # Use explicit ctype + cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Connection is closed", + ) + + # Validate sqltype + valid_sqltypes = [ + ConstantsDDBC.SQL_CHAR.value, + ConstantsDDBC.SQL_WCHAR.value, + SQL_WMETADATA + ] + if sqltype not in valid_sqltypes: + log('warning', "Invalid sqltype attempted: %s", sanitize_user_input(str(sqltype))) + raise ProgrammingError( + driver_error=f"Invalid sqltype: {sqltype}", + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", + ) + + # Set default encoding based on sqltype if not provided + if encoding is None: + if sqltype == ConstantsDDBC.SQL_CHAR.value: + encoding = 'utf-8' # Default for SQL_CHAR in Python 3 + else: # SQL_WCHAR or SQL_WMETADATA + encoding = 'utf-16le' # Default for SQL_WCHAR in Python 3 + + # Validate encoding using cached validation for better performance + if not _validate_encoding(encoding): + log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding))) + raise ProgrammingError( + driver_error=f"Unsupported encoding: {encoding}", + ddbc_error=f"The encoding '{encoding}' is not supported by Python", + ) + + # Normalize encoding to lowercase for consistency + encoding = encoding.lower() + + # Set default ctype based on encoding if not provided + if ctype is None: + if encoding in UTF16_ENCODINGS: + ctype = ConstantsDDBC.SQL_WCHAR.value + else: + ctype = ConstantsDDBC.SQL_CHAR.value + + # Validate ctype + valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value] + if ctype not in valid_ctypes: + log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype))) + raise ProgrammingError( + driver_error=f"Invalid ctype: {ctype}", + ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})", + ) + + # Store the decoding settings for the specified sqltype + self._decoding_settings[sqltype] = { + 'encoding': encoding, + 'ctype': ctype + } + + # Log with sanitized values for security + sqltype_name = { + ConstantsDDBC.SQL_CHAR.value: "SQL_CHAR", + ConstantsDDBC.SQL_WCHAR.value: "SQL_WCHAR", + SQL_WMETADATA: "SQL_WMETADATA" + }.get(sqltype, str(sqltype)) + + log('info', "Text decoding set for %s to %s with ctype %s", + sqltype_name, sanitize_user_input(encoding), sanitize_user_input(str(ctype))) + + def getdecoding(self, sqltype): + """ + Gets the current text decoding settings for the specified SQL type. + + Args: + sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. + + Returns: + dict: A dictionary containing 'encoding' and 'ctype' keys for the specified sqltype. + + Raises: + ProgrammingError: If the sqltype is invalid. + InterfaceError: If the connection is closed. + + Example: + settings = cnxn.getdecoding(mssql_python.SQL_CHAR) + print(f"SQL_CHAR encoding: {settings['encoding']}") + print(f"SQL_CHAR ctype: {settings['ctype']}") + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Connection is closed", + ) + + # Validate sqltype + valid_sqltypes = [ + ConstantsDDBC.SQL_CHAR.value, + ConstantsDDBC.SQL_WCHAR.value, + SQL_WMETADATA + ] + if sqltype not in valid_sqltypes: + raise ProgrammingError( + driver_error=f"Invalid sqltype: {sqltype}", + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", + ) + + return self._decoding_settings[sqltype].copy() + def cursor(self) -> Cursor: """ Return a new Cursor object using the connection. diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index c9d043d38..74cccdc92 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -18,7 +18,8 @@ - test_rollback_on_close: Test that rollback occurs on connection close if autocommit is False. """ -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, ProgrammingError +import mssql_python import pytest import time from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR @@ -552,17 +553,15 @@ def test_setencoding_none_parameters(db_connection): def test_setencoding_invalid_encoding(db_connection): """Test setencoding with invalid encoding.""" - from mssql_python.exceptions import ProgrammingError with pytest.raises(ProgrammingError) as exc_info: db_connection.setencoding(encoding='invalid-encoding-name') - assert "Unknown encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" def test_setencoding_invalid_ctype(db_connection): """Test setencoding with invalid ctype.""" - from mssql_python.exceptions import ProgrammingError with pytest.raises(ProgrammingError) as exc_info: db_connection.setencoding(encoding='utf-8', ctype=999) @@ -572,7 +571,6 @@ def test_setencoding_invalid_ctype(db_connection): def test_setencoding_closed_connection(conn_str): """Test setencoding on closed connection.""" - from mssql_python.exceptions import InterfaceError temp_conn = connect(conn_str) temp_conn.close() @@ -580,7 +578,7 @@ def test_setencoding_closed_connection(conn_str): with pytest.raises(InterfaceError) as exc_info: temp_conn.setencoding(encoding='utf-8') - assert "closed connection" in str(exc_info.value).lower(), "Should raise InterfaceError for closed connection" + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" def test_setencoding_constants_access(): """Test that SQL_CHAR and SQL_WCHAR constants are accessible.""" @@ -829,20 +827,8 @@ def test_setencoding_with_explicit_ctype_sql_wchar(conn_str): finally: conn.close() -def test_setencoding_invalid_encoding(conn_str): - """Test setencoding with invalid encoding raises ProgrammingError""" - from mssql_python.exceptions import ProgrammingError - - conn = connect(conn_str) - try: - with pytest.raises(ProgrammingError, match="Unsupported encoding"): - conn.setencoding('invalid-encoding-name') - finally: - conn.close() - -def test_setencoding_invalid_ctype(conn_str): +def test_setencoding_invalid_ctype_error(conn_str): """Test setencoding with invalid ctype raises ProgrammingError""" - from mssql_python.exceptions import ProgrammingError conn = connect(conn_str) try: @@ -851,14 +837,6 @@ def test_setencoding_invalid_ctype(conn_str): finally: conn.close() -def test_setencoding_closed_connection(conn_str): - """Test setencoding on closed connection raises InterfaceError""" - conn = connect(conn_str) - conn.close() - - with pytest.raises(InterfaceError, match="Connection is closed"): - conn.setencoding('utf-8') - def test_setencoding_case_insensitive_encoding(conn_str): """Test setencoding with case variations""" conn = connect(conn_str) @@ -923,4 +901,431 @@ def test_setencoding_cp1252(conn_str): assert encoding_info['encoding'] == 'cp1252' assert encoding_info['ctype'] == SQL_CHAR finally: - conn.close() \ No newline at end of file + conn.close() + +def test_setdecoding_default_settings(db_connection): + """Test that default decoding settings are correct for all SQL types.""" + + # Check SQL_CHAR defaults + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert sql_char_settings['encoding'] == 'utf-8', "Default SQL_CHAR encoding should be utf-8" + assert sql_char_settings['ctype'] == mssql_python.SQL_CHAR, "Default SQL_CHAR ctype should be SQL_CHAR" + + # Check SQL_WCHAR defaults + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert sql_wchar_settings['encoding'] == 'utf-16le', "Default SQL_WCHAR encoding should be utf-16le" + assert sql_wchar_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WCHAR ctype should be SQL_WCHAR" + + # Check SQL_WMETADATA defaults + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert sql_wmetadata_settings['encoding'] == 'utf-16le', "Default SQL_WMETADATA encoding should be utf-16le" + assert sql_wmetadata_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WMETADATA ctype should be SQL_WCHAR" + +def test_setdecoding_basic_functionality(db_connection): + """Test basic setdecoding functionality for different SQL types.""" + + # Test setting SQL_CHAR decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "SQL_CHAR encoding should be set to latin-1" + assert settings['ctype'] == mssql_python.SQL_CHAR, "SQL_CHAR ctype should default to SQL_CHAR for latin-1" + + # Test setting SQL_WCHAR decoding + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16be', "SQL_WCHAR encoding should be set to utf-16be" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WCHAR ctype should default to SQL_WCHAR for utf-16be" + + # Test setting SQL_WMETADATA decoding + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16le', "SQL_WMETADATA encoding should be set to utf-16le" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WMETADATA ctype should default to SQL_WCHAR" + +def test_setdecoding_automatic_ctype_detection(db_connection): + """Test automatic ctype detection based on encoding for different SQL types.""" + + # UTF-16 variants should default to SQL_WCHAR + utf16_encodings = ['utf-16', 'utf-16le', 'utf-16be'] + for encoding in utf16_encodings: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, f"SQL_CHAR with {encoding} should auto-detect SQL_WCHAR ctype" + + # Other encodings should default to SQL_CHAR + other_encodings = ['utf-8', 'latin-1', 'ascii', 'cp1252'] + for encoding in other_encodings: + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, f"SQL_WCHAR with {encoding} should auto-detect SQL_CHAR ctype" + +def test_setdecoding_explicit_ctype_override(db_connection): + """Test that explicit ctype parameter overrides automatic detection.""" + + # Set SQL_CHAR with UTF-8 encoding but explicit SQL_WCHAR ctype + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR when explicitly set" + + # Set SQL_WCHAR with UTF-16LE encoding but explicit SQL_CHAR ctype + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be utf-16le" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR when explicitly set" + +def test_setdecoding_none_parameters(db_connection): + """Test setdecoding with None parameters uses appropriate defaults.""" + + # Test SQL_CHAR with encoding=None (should use utf-8 default) + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with encoding=None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR for utf-8" + + # Test SQL_WCHAR with encoding=None (should use utf-16le default) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "SQL_WCHAR with encoding=None should use utf-16le default" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR for utf-16le" + + # Test with both parameters None + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None, ctype=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with both None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should default to SQL_CHAR" + +def test_setdecoding_invalid_sqltype(db_connection): + """Test setdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(999, encoding='utf-8') + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_setdecoding_invalid_encoding(db_connection): + """Test setdecoding with invalid encoding raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='invalid-encoding-name') + + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" + +def test_setdecoding_invalid_ctype(db_connection): + """Test setdecoding with invalid ctype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=999) + + assert "Invalid ctype" in str(exc_info.value), "Should raise ProgrammingError for invalid ctype" + assert "999" in str(exc_info.value), "Error message should include the invalid ctype value" + +def test_setdecoding_closed_connection(conn_str): + """Test setdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_setdecoding_constants_access(): + """Test that SQL constants are accessible.""" + + # Test constants exist and have correct values + assert hasattr(mssql_python, 'SQL_CHAR'), "SQL_CHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WCHAR'), "SQL_WCHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WMETADATA'), "SQL_WMETADATA constant should be available" + + assert mssql_python.SQL_CHAR == 1, "SQL_CHAR should have value 1" + assert mssql_python.SQL_WCHAR == -8, "SQL_WCHAR should have value -8" + assert mssql_python.SQL_WMETADATA == -99, "SQL_WMETADATA should have value -99" + +def test_setdecoding_with_constants(db_connection): + """Test setdecoding using module constants.""" + + # Test with SQL_CHAR constant + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, "Should accept SQL_CHAR constant" + + # Test with SQL_WCHAR constant + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, "Should accept SQL_WCHAR constant" + + # Test with SQL_WMETADATA constant + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16be', "Should accept SQL_WMETADATA constant" + +def test_setdecoding_common_encodings(db_connection): + """Test setdecoding with various common encodings.""" + + common_encodings = [ + 'utf-8', + 'utf-16le', + 'utf-16be', + 'utf-16', + 'latin-1', + 'ascii', + 'cp1252' + ] + + for encoding in common_encodings: + try: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_CHAR decoding to {encoding}" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_WCHAR decoding to {encoding}" + except Exception as e: + pytest.fail(f"Failed to set valid encoding {encoding}: {e}") + +def test_setdecoding_case_insensitive_encoding(db_connection): + """Test setdecoding with case variations normalizes encoding.""" + + # Test various case formats + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='UTF-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be normalized to lowercase" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='Utf-16LE') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be normalized to lowercase" + +def test_setdecoding_independent_sql_types(db_connection): + """Test that decoding settings for different SQL types are independent.""" + + # Set different encodings for each SQL type + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + + # Verify each maintains its own settings + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + + assert sql_char_settings['encoding'] == 'utf-8', "SQL_CHAR should maintain utf-8" + assert sql_wchar_settings['encoding'] == 'utf-16le', "SQL_WCHAR should maintain utf-16le" + assert sql_wmetadata_settings['encoding'] == 'utf-16be', "SQL_WMETADATA should maintain utf-16be" + +def test_setdecoding_override_previous(db_connection): + """Test setdecoding overrides previous settings for the same SQL type.""" + + # Set initial decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Initial encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_CHAR, "Initial ctype should be SQL_CHAR" + + # Override with different settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Encoding should be overridden to latin-1" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be overridden to SQL_WCHAR" + +def test_getdecoding_invalid_sqltype(db_connection): + """Test getdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.getdecoding(999) + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_getdecoding_closed_connection(conn_str): + """Test getdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.getdecoding(mssql_python.SQL_CHAR) + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_getdecoding_returns_copy(db_connection): + """Test getdecoding returns a copy (not reference).""" + + # Set custom decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Get settings twice + settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + + # Should be equal but not the same object + assert settings1 == settings2, "Settings should be equal" + assert settings1 is not settings2, "Settings should be different objects" + + # Modifying one shouldn't affect the other + settings1['encoding'] = 'modified' + assert settings2['encoding'] != 'modified', "Modification should not affect other copy" + +def test_setdecoding_getdecoding_consistency(db_connection): + """Test that setdecoding and getdecoding work consistently together.""" + + test_cases = [ + (mssql_python.SQL_CHAR, 'utf-8', mssql_python.SQL_CHAR), + (mssql_python.SQL_CHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WCHAR, 'latin-1', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16be', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16le', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, expected_ctype in test_cases: + db_connection.setdecoding(sqltype, encoding=encoding) + settings = db_connection.getdecoding(sqltype) + assert settings['encoding'] == encoding.lower(), f"Encoding should be {encoding.lower()}" + assert settings['ctype'] == expected_ctype, f"ctype should be {expected_ctype}" + +def test_setdecoding_persistence_across_cursors(db_connection): + """Test that decoding settings persist across cursor operations.""" + + # Set custom decoding settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_CHAR) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be', ctype=mssql_python.SQL_WCHAR) + + # Create cursors and verify settings persist + cursor1 = db_connection.cursor() + char_settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings1 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + cursor2 = db_connection.cursor() + char_settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings2 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + # Settings should persist across cursor creation + assert char_settings1 == char_settings2, "SQL_CHAR settings should persist across cursors" + assert wchar_settings1 == wchar_settings2, "SQL_WCHAR settings should persist across cursors" + + assert char_settings1['encoding'] == 'latin-1', "SQL_CHAR encoding should remain latin-1" + assert wchar_settings1['encoding'] == 'utf-16be', "SQL_WCHAR encoding should remain utf-16be" + + cursor1.close() + cursor2.close() + +def test_setdecoding_before_and_after_operations(db_connection): + """Test that setdecoding works both before and after database operations.""" + cursor = db_connection.cursor() + + try: + # Initial decoding setting + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Perform database operation + cursor.execute("SELECT 'Initial test' as message") + result1 = cursor.fetchone() + assert result1[0] == 'Initial test', "Initial operation failed" + + # Change decoding after operation + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Failed to change decoding after operation" + + # Perform another operation with new decoding + cursor.execute("SELECT 'Changed decoding test' as message") + result2 = cursor.fetchone() + assert result2[0] == 'Changed decoding test', "Operation after decoding change failed" + + except Exception as e: + pytest.fail(f"Decoding change test failed: {e}") + finally: + cursor.close() + +def test_setdecoding_all_sql_types_independently(conn_str): + """Test setdecoding with all SQL types on a fresh connection.""" + + conn = connect(conn_str) + try: + # Test each SQL type with different configurations + test_configs = [ + (mssql_python.SQL_CHAR, 'ascii', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16be', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, ctype in test_configs: + conn.setdecoding(sqltype, encoding=encoding, ctype=ctype) + settings = conn.getdecoding(sqltype) + assert settings['encoding'] == encoding, f"Failed to set encoding for sqltype {sqltype}" + assert settings['ctype'] == ctype, f"Failed to set ctype for sqltype {sqltype}" + + finally: + conn.close() + +def test_setdecoding_security_logging(db_connection): + """Test that setdecoding logs invalid attempts safely.""" + + # These should raise exceptions but not crash due to logging + test_cases = [ + (999, 'utf-8', None), # Invalid sqltype + (mssql_python.SQL_CHAR, 'invalid-encoding', None), # Invalid encoding + (mssql_python.SQL_CHAR, 'utf-8', 999), # Invalid ctype + ] + + for sqltype, encoding, ctype in test_cases: + with pytest.raises(ProgrammingError): + db_connection.setdecoding(sqltype, encoding=encoding, ctype=ctype) + +@pytest.mark.skip("Skipping Unicode data tests till we have support for Unicode") +def test_setdecoding_with_unicode_data(db_connection): + """Test setdecoding with actual Unicode data operations.""" + + # Test different decoding configurations with Unicode data + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + + cursor = db_connection.cursor() + + try: + # Create test table with both CHAR and NCHAR columns + cursor.execute(""" + CREATE TABLE #test_decoding_unicode ( + char_col VARCHAR(100), + nchar_col NVARCHAR(100) + ) + """) + + # Test various Unicode strings + test_strings = [ + "Hello, World!", + "Hello, 世界!", # Chinese + "Привет, мир!", # Russian + "مرحبا بالعالم", # Arabic + ] + + for test_string in test_strings: + # Insert data + cursor.execute( + "INSERT INTO #test_decoding_unicode (char_col, nchar_col) VALUES (?, ?)", + test_string, test_string + ) + + # Retrieve and verify + cursor.execute("SELECT char_col, nchar_col FROM #test_decoding_unicode WHERE char_col = ?", test_string) + result = cursor.fetchone() + + assert result is not None, f"Failed to retrieve Unicode string: {test_string}" + assert result[0] == test_string, f"CHAR column mismatch: expected {test_string}, got {result[0]}" + assert result[1] == test_string, f"NCHAR column mismatch: expected {test_string}, got {result[1]}" + + # Clear for next test + cursor.execute("DELETE FROM #test_decoding_unicode") + + except Exception as e: + pytest.fail(f"Unicode data test failed with custom decoding: {e}") + finally: + try: + cursor.execute("DROP TABLE #test_decoding_unicode") + except: + pass + cursor.close() \ No newline at end of file