From d2f2397c75807a84d59498333d6485ecdbf2a79f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Mon, 14 Jul 2025 12:21:41 +0530 Subject: [PATCH 01/12] FEAT: Adding authentication module and adding new auth types --- README.md | 9 +- mssql_python/auth.py | 172 +++++++++++++++++++++++++++++++++++++ mssql_python/connection.py | 7 ++ 3 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 mssql_python/auth.py diff --git a/README.md b/README.md index 8359ccb64..42c028836 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ By adhering to the DB API 2.0 specification, the mssql-python module ensures com ### Support for Microsoft Entra ID Authentication -The Microsoft mssql-python driver enables Python applications to connect to Microsoft SQL Server, Azure SQL Database, or Azure SQL Managed Instance using Microsoft Entra ID identities. It supports various authentication methods, including username and password, Microsoft Entra managed identity, and Integrated Windows Authentication in a federated, domain-joined environment. Additionally, the driver supports Microsoft Entra interactive authentication and Microsoft Entra managed identity authentication for both system-assigned and user-assigned managed identities. +The Microsoft mssql-python driver enables Python applications to connect to Microsoft SQL Server, Azure SQL Database, or Azure SQL Managed Instance using Microsoft Entra ID identities. It supports a variety of authentication methods, including username and password, Microsoft Entra managed identity (system-assigned and user-assigned), Integrated Windows Authentication in a federated, domain-joined environment, interactive authentication via browser, device code flow for environments without browser access, and the default authentication method based on environment and configuration. This flexibility allows developers to choose the most suitable authentication approach for their deployment scenario. EntraID authentication is now fully supported on MacOS and Linux but with certain limitations as mentioned in the table: @@ -56,9 +56,16 @@ EntraID authentication is now fully supported on MacOS and Linux but with certai |----------------------|----------------|---------------------|-------| | ActiveDirectoryPassword | ✅ Yes | ✅ Yes | Username/password-based authentication | | ActiveDirectoryInteractive | ✅ Yes | ❌ No | Only works on Windows | +| ActiveDirectoryInteractive | ✅ Yes | ✅ Yes | Interactive login via browser; requires user interaction | | ActiveDirectoryMSI (Managed Identity) | ✅ Yes | ✅ Yes | For Azure VMs/containers with managed identity | | ActiveDirectoryServicePrincipal | ✅ Yes | ✅ Yes | Use client ID and secret or certificate | | ActiveDirectoryIntegrated | ✅ Yes | ❌ No | Only works on Windows (requires Kerberos/SSPI) | +| ActiveDirectoryDeviceCode | ✅ Yes | ✅ Yes | Device code flow for authentication; suitable for environments without browser access | +| ActiveDirectoryDefault | ✅ Yes | ✅ Yes | Uses default authentication method based on environment and configuration | + +**NOTE**: For using Access Token, the connection string *must not* contain `UID`, `PWD`, `Authentication`, or `Trusted_Connection` keywords. + +**NOTE**: For using ActiveDirectoryDeviceCode, make sure to specify a `Connect Timeout` that provides enough time to go through the device code flow authentication process. ### Enhanced Pythonic Features diff --git a/mssql_python/auth.py b/mssql_python/auth.py new file mode 100644 index 000000000..5b1da5a6e --- /dev/null +++ b/mssql_python/auth.py @@ -0,0 +1,172 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +This module handles authentication for the mssql_python package. +""" + +import platform +import struct +from typing import Tuple, Dict, Optional, Union +from mssql_python.logging_config import get_logger, ENABLE_LOGGING + +logger = get_logger() + +class AuthType: + """Constants for authentication types""" + INTERACTIVE = "activedirectoryinteractive" + DEVICE_CODE = "activedirectorydevicecode" + DEFAULT = "activedirectorydefault" + +class AADAuth: + """Handles Azure Active Directory authentication""" + + @staticmethod + def get_token_struct(token: str) -> bytes: + """Convert token to SQL Server compatible format""" + token_bytes = token.encode("UTF-16-LE") + return struct.pack(f" bytes: + """Get token using DefaultAzureCredential""" + from azure.identity import DefaultAzureCredential + + try: + # DefaultAzureCredential will automatically use the best available method + # based on the environment (e.g., managed identity, environment variables) + credential = DefaultAzureCredential() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except Exception as e: + raise RuntimeError(f"Failed to create DefaultAzureCredential: {e}") + + @staticmethod + def get_device_code_token() -> bytes: + """Get token using DeviceCodeCredential""" + from azure.identity import DeviceCodeCredential + + try: + credential = DeviceCodeCredential() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except Exception as e: + raise RuntimeError(f"Failed to create DeviceCodeCredential: {e}") + + @staticmethod + def get_interactive_token() -> bytes: + """Get token using InteractiveBrowserCredential""" + from azure.identity import InteractiveBrowserCredential + + try: + credential = InteractiveBrowserCredential() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except Exception as e: + raise RuntimeError(f"Failed to create InteractiveBrowserCredential: {e}") + +def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]: + """ + Process connection parameters and extract authentication type. + + Args: + parameters: List of connection string parameters + + Returns: + Tuple[list, Optional[str]]: Modified parameters and authentication type + + Raises: + ValueError: If an invalid authentication type is provided + """ + modified_parameters = [] + auth_type = None + + for param in parameters: + param = param.strip() + if not param: + continue + + if "=" not in param: + modified_parameters.append(param) + continue + + key, value = param.split("=", 1) + key_lower = key.lower() + value_lower = value.lower() + + if key_lower == "authentication": + if value_lower == AuthType.INTERACTIVE: + auth_type = "interactive" + if platform.system().lower() != "windows": + modified_parameters.append(param) + elif value_lower == AuthType.DEVICE_CODE: + auth_type = "devicecode" + elif value_lower == AuthType.DEFAULT: + auth_type = "default" + else: + raise ValueError(f"Invalid authentication type: {value}. " + f"Supported types are: {AuthType.INTERACTIVE}, " + f"{AuthType.DEVICE_CODE}, {AuthType.DEFAULT}") + else: + modified_parameters.append(param) + + return modified_parameters, auth_type + +def remove_sensitive_params(parameters: list) -> list: + """Remove sensitive parameters from connection string""" + exclude_keys = [ + "uid=", "pwd=", "encrypt=", "trustservercertificate=", "authentication=" + ] + return [ + param for param in parameters + if not any(param.lower().startswith(exclude) for exclude in exclude_keys) + ] + +def get_auth_token(auth_type: str) -> Optional[bytes]: + """Get authentication token based on auth type""" + if not auth_type: + return None + + if auth_type == "default": + return AADAuth.get_default_token() + elif auth_type == "devicecode": + return AADAuth.get_device_code_token() + elif auth_type == "interactive" and platform.system().lower() != "windows": + return AADAuth.get_interactive_token() + return None + +def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]: + """ + Process connection string and handle authentication. + + Args: + connection_string: The connection string to process + + Returns: + Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed + + Raises: + ValueError: If the connection string is invalid or empty + """ + # Check type first + if not isinstance(connection_string, str): + raise ValueError("Connection string must be a string") + + # Then check if empty + if not connection_string: + raise ValueError("Connection string cannot be empty") + + parameters = connection_string.split(";") + + # Validate that there's at least one valid parameter + if not any('=' in param for param in parameters): + raise ValueError("Invalid connection string format") + + modified_parameters, auth_type = process_auth_parameters(parameters) + + if auth_type: + modified_parameters = remove_sensitive_params(modified_parameters) + token_struct = get_auth_token(auth_type) + if token_struct: + return ";".join(modified_parameters) + ";", {1256: token_struct} + + return ";".join(modified_parameters) + ";", None \ No newline at end of file diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 2336ae365..c2dcb6daa 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -18,6 +18,7 @@ from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager from mssql_python.exceptions import DatabaseError, InterfaceError +from mssql_python.auth import process_connection_string logger = get_logger() @@ -64,6 +65,12 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef connection_str, **kwargs ) self._attrs_before = attrs_before or {} + if "authentication" in self.connection_str.lower(): + connection_result = process_connection_string(self.connection_str) + self.connection_str = connection_result[0] + if connection_result[1]: + self._attrs_before.update(connection_result[1]) + self._closed = False # Using WeakSet which automatically removes cursors when they are no longer in use From 3320013b2ce0292ce0d2253c003819da58c90867 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Mon, 14 Jul 2025 12:41:32 +0530 Subject: [PATCH 02/12] Resolving comments --- README.md | 1 - mssql_python/connection.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 42c028836..32f9eaebb 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,6 @@ EntraID authentication is now fully supported on MacOS and Linux but with certai | Authentication Method | Windows Support | macOS/Linux Support | Notes | |----------------------|----------------|---------------------|-------| | ActiveDirectoryPassword | ✅ Yes | ✅ Yes | Username/password-based authentication | -| ActiveDirectoryInteractive | ✅ Yes | ❌ No | Only works on Windows | | ActiveDirectoryInteractive | ✅ Yes | ✅ Yes | Interactive login via browser; requires user interaction | | ActiveDirectoryMSI (Managed Identity) | ✅ Yes | ✅ Yes | For Azure VMs/containers with managed identity | | ActiveDirectoryServicePrincipal | ✅ Yes | ✅ Yes | Use client ID and secret or certificate | diff --git a/mssql_python/connection.py b/mssql_python/connection.py index c2dcb6daa..dde010a29 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -11,6 +11,7 @@ - Cursors are also cleaned up automatically when no longer referenced, to prevent memory leaks. """ import weakref +import re from mssql_python.cursor import Cursor from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python.constants import ConstantsDDBC as ddbc_sql_const @@ -65,7 +66,7 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef connection_str, **kwargs ) self._attrs_before = attrs_before or {} - if "authentication" in self.connection_str.lower(): + if re.search(r"authentication", self.connection_str, re.IGNORECASE): connection_result = process_connection_string(self.connection_str) self.connection_str = connection_result[0] if connection_result[1]: From 7eaeebb7c825a031718c9bd31cd18f12b80759de Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 15 Jul 2025 12:29:24 +0530 Subject: [PATCH 03/12] Resolving comments --- mssql_python/auth.py | 28 +++--- mssql_python/connection.py | 5 ++ mssql_python/constants.py | 6 ++ tests/test_007_auth.py | 180 +++++++++++++++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 17 deletions(-) create mode 100644 tests/test_007_auth.py diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 5b1da5a6e..c672bb745 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -8,15 +8,10 @@ import struct from typing import Tuple, Dict, Optional, Union from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.constants import AuthType logger = get_logger() -class AuthType: - """Constants for authentication types""" - INTERACTIVE = "activedirectoryinteractive" - DEVICE_CODE = "activedirectorydevicecode" - DEFAULT = "activedirectorydefault" - class AADAuth: """Handles Azure Active Directory authentication""" @@ -94,20 +89,19 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]: value_lower = value.lower() if key_lower == "authentication": - if value_lower == AuthType.INTERACTIVE: + # Check for supported authentication types and set auth_type accordingly + if value_lower == AuthType.INTERACTIVE.value: + # Interactive authentication (browser-based); only append parameter for non-Windows auth_type = "interactive" - if platform.system().lower() != "windows": - modified_parameters.append(param) - elif value_lower == AuthType.DEVICE_CODE: + if platform.system().lower() == "windows": + auth_type = None # Skip if on Windows + elif value_lower == AuthType.DEVICE_CODE.value: + # Device code authentication (for devices without browser) auth_type = "devicecode" - elif value_lower == AuthType.DEFAULT: + elif value_lower == AuthType.DEFAULT.value: + # Default authentication (uses DefaultAzureCredential) auth_type = "default" - else: - raise ValueError(f"Invalid authentication type: {value}. " - f"Supported types are: {AuthType.INTERACTIVE}, " - f"{AuthType.DEVICE_CODE}, {AuthType.DEFAULT}") - else: - modified_parameters.append(param) + modified_parameters.append(param) return modified_parameters, auth_type diff --git a/mssql_python/connection.py b/mssql_python/connection.py index dde010a29..e98612468 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -66,6 +66,11 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef connection_str, **kwargs ) self._attrs_before = attrs_before or {} + + # Check if the connection string contains authentication parameters + # This is important for processing the connection string correctly. + # If authentication is specified, it will be processed to handle + # different authentication types like interactive, device code, etc. if re.search(r"authentication", self.connection_str, re.IGNORECASE): connection_result = process_connection_string(self.connection_str) self.connection_str = connection_result[0] diff --git a/mssql_python/constants.py b/mssql_python/constants.py index aade503c7..81e60d37e 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -116,3 +116,9 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + +class AuthType(Enum): + """Constants for authentication types""" + INTERACTIVE = "activedirectoryinteractive" + DEVICE_CODE = "activedirectorydevicecode" + DEFAULT = "activedirectorydefault" \ No newline at end of file diff --git a/tests/test_007_auth.py b/tests/test_007_auth.py new file mode 100644 index 000000000..ae3ebbc7e --- /dev/null +++ b/tests/test_007_auth.py @@ -0,0 +1,180 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +Tests for the auth module. +""" + +import pytest +import platform +import sys +from mssql_python.auth import ( + AuthType, + AADAuth, + process_auth_parameters, + remove_sensitive_params, + get_auth_token, + process_connection_string +) + +# Test data +SAMPLE_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6I" + +@pytest.fixture(autouse=True) +def setup_azure_identity(): + """Setup mock azure.identity module""" + class MockToken: + token = SAMPLE_TOKEN + + class MockDefaultAzureCredential: + def get_token(self, scope): + return MockToken() + + class MockDeviceCodeCredential: + def get_token(self, scope): + return MockToken() + + class MockInteractiveBrowserCredential: + def get_token(self, scope): + return MockToken() + + class MockIdentity: + DefaultAzureCredential = MockDefaultAzureCredential + DeviceCodeCredential = MockDeviceCodeCredential + InteractiveBrowserCredential = MockInteractiveBrowserCredential + + # Create mock azure module if it doesn't exist + if 'azure' not in sys.modules: + sys.modules['azure'] = type('MockAzure', (), {})() + + # Add identity module to azure + sys.modules['azure.identity'] = MockIdentity() + + yield + + # Cleanup + if 'azure.identity' in sys.modules: + del sys.modules['azure.identity'] + +class TestAuthType: + def test_auth_type_constants(self): + assert AuthType.INTERACTIVE == "activedirectoryinteractive" + assert AuthType.DEVICE_CODE == "activedirectorydevicecode" + assert AuthType.DEFAULT == "activedirectorydefault" + +class TestAADAuth: + def test_get_token_struct(self): + token_struct = AADAuth.get_token_struct(SAMPLE_TOKEN) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + + def test_get_default_token(self): + token_struct = AADAuth.get_default_token() + assert isinstance(token_struct, bytes) + + def test_get_device_code_token(self): + token_struct = AADAuth.get_device_code_token() + assert isinstance(token_struct, bytes) + + def test_get_interactive_token(self): + token_struct = AADAuth.get_interactive_token() + assert isinstance(token_struct, bytes) + +class TestProcessAuthParameters: + def test_empty_parameters(self): + modified_params, auth_type = process_auth_parameters([]) + assert modified_params == [] + assert auth_type is None + + def test_interactive_auth_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Windows") + params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryInteractive" not in modified_params + assert auth_type == "interactive" + + def test_interactive_auth_non_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Darwin") + params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryInteractive" in modified_params + assert auth_type == "interactive" + + def test_device_code_auth(self): + params = ["Authentication=ActiveDirectoryDeviceCode", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert auth_type == "devicecode" + + def test_default_auth(self): + params = ["Authentication=ActiveDirectoryDefault", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert auth_type == "default" + +class TestRemoveSensitiveParams: + def test_remove_sensitive_parameters(self): + params = [ + "Server=test", + "UID=user", + "PWD=password", + "Encrypt=yes", + "TrustServerCertificate=yes", + "Authentication=ActiveDirectoryDefault", + "Database=testdb" + ] + filtered_params = remove_sensitive_params(params) + assert "Server=test" in filtered_params + assert "Database=testdb" in filtered_params + assert "UID=user" not in filtered_params + assert "PWD=password" not in filtered_params + assert "Encrypt=yes" not in filtered_params + assert "TrustServerCertificate=yes" not in filtered_params + assert "Authentication=ActiveDirectoryDefault" not in filtered_params + +class TestProcessConnectionString: + def test_process_connection_string_with_default_auth(self): + conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert 1256 in attrs + assert isinstance(attrs[1256], bytes) + + def test_process_connection_string_no_auth(self): + conn_str = "Server=test;Database=testdb;UID=user;PWD=password" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert "UID=user" in result_str + assert "PWD=password" in result_str + assert attrs is None + + def test_process_connection_string_interactive_non_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Darwin") + conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert 1256 in attrs + assert isinstance(attrs[1256], bytes) + +def test_error_handling(): + # Empty string should raise ValueError + with pytest.raises(ValueError, match="Connection string cannot be empty"): + process_connection_string("") + + # Invalid connection string should raise ValueError + with pytest.raises(ValueError, match="Invalid connection string format"): + process_connection_string("InvalidConnectionString") + + # Invalid auth type should raise ValueError + with pytest.raises(ValueError, match="Invalid authentication type"): + conn_str = "Server=test;Authentication=InvalidAuth" + process_connection_string(conn_str) + + # Test non-string input + with pytest.raises(ValueError, match="Connection string must be a string"): + process_connection_string(None) \ No newline at end of file From d8b5bfd6037f84967cf018cd95d71e310eeec1e2 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 15 Jul 2025 21:52:40 +0530 Subject: [PATCH 04/12] Adding comments --- mssql_python/auth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index c672bb745..0729e6508 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -124,6 +124,8 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: return AADAuth.get_default_token() elif auth_type == "devicecode": return AADAuth.get_device_code_token() + # If interactive authentication is requested, use InteractiveBrowserCredential + # but only if not on Windows, since in Windows: AADInteractive is supported. elif auth_type == "interactive" and platform.system().lower() != "windows": return AADAuth.get_interactive_token() return None From 9262f73a4d79589dcc347a7b881dc36fd9e8940d Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 16 Jul 2025 10:21:02 +0530 Subject: [PATCH 05/12] Altering testcases --- mssql_python/auth.py | 4 ++-- tests/test_007_auth.py | 15 +++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 0729e6508..58a3110d5 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -92,9 +92,9 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]: # Check for supported authentication types and set auth_type accordingly if value_lower == AuthType.INTERACTIVE.value: # Interactive authentication (browser-based); only append parameter for non-Windows - auth_type = "interactive" if platform.system().lower() == "windows": - auth_type = None # Skip if on Windows + continue # Skip adding this parameter for Windows + auth_type = "interactive" elif value_lower == AuthType.DEVICE_CODE.value: # Device code authentication (for devices without browser) auth_type = "devicecode" diff --git a/tests/test_007_auth.py b/tests/test_007_auth.py index ae3ebbc7e..52cea8178 100644 --- a/tests/test_007_auth.py +++ b/tests/test_007_auth.py @@ -8,13 +8,13 @@ import platform import sys from mssql_python.auth import ( - AuthType, AADAuth, process_auth_parameters, remove_sensitive_params, get_auth_token, process_connection_string ) +from mssql_python.constants import AuthType # Test data SAMPLE_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6I" @@ -57,9 +57,9 @@ class MockIdentity: class TestAuthType: def test_auth_type_constants(self): - assert AuthType.INTERACTIVE == "activedirectoryinteractive" - assert AuthType.DEVICE_CODE == "activedirectorydevicecode" - assert AuthType.DEFAULT == "activedirectorydefault" + assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" + assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" + assert AuthType.DEFAULT.value == "activedirectorydefault" class TestAADAuth: def test_get_token_struct(self): @@ -90,7 +90,7 @@ def test_interactive_auth_windows(self, monkeypatch): params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] modified_params, auth_type = process_auth_parameters(params) assert "Authentication=ActiveDirectoryInteractive" not in modified_params - assert auth_type == "interactive" + assert auth_type == None def test_interactive_auth_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") @@ -170,11 +170,6 @@ def test_error_handling(): with pytest.raises(ValueError, match="Invalid connection string format"): process_connection_string("InvalidConnectionString") - # Invalid auth type should raise ValueError - with pytest.raises(ValueError, match="Invalid authentication type"): - conn_str = "Server=test;Authentication=InvalidAuth" - process_connection_string(conn_str) - # Test non-string input with pytest.raises(ValueError, match="Connection string must be a string"): process_connection_string(None) \ No newline at end of file From be1fba4d4ff1217aa21ed519a32d5c7993f7649a Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 16 Jul 2025 12:27:06 +0530 Subject: [PATCH 06/12] removed ENABLE_LOGGING --- mssql_python/auth.py | 2 +- mssql_python/connection.py | 18 +++++++++--------- mssql_python/cursor.py | 10 +++++----- mssql_python/exceptions.py | 6 +++--- mssql_python/helpers.py | 4 ++-- mssql_python/logging_config.py | 8 ++++---- tests/test_006_logging.py | 6 +++--- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 58a3110d5..d67bc5efc 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -7,7 +7,7 @@ import platform import struct from typing import Tuple, Dict, Optional, Union -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger from mssql_python.constants import AuthType logger = get_logger() diff --git a/mssql_python/connection.py b/mssql_python/connection.py index e98612468..71e8d209d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -13,7 +13,7 @@ import weakref import re from mssql_python.cursor import Cursor -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import add_driver_to_connection_str, check_error from mssql_python import ddbc_bindings @@ -126,7 +126,7 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st continue conn_str += f"{key}={value};" - if ENABLE_LOGGING: + if logger: logger.info("Final connection string: %s", conn_str) return conn_str @@ -150,7 +150,7 @@ def autocommit(self, value: bool) -> None: None """ self.setautocommit(value) - if ENABLE_LOGGING: + if logger: logger.info("Autocommit mode set to %s.", value) def setautocommit(self, value: bool = True) -> None: @@ -206,7 +206,7 @@ def commit(self) -> None: """ # Commit the current transaction self._conn.commit() - if ENABLE_LOGGING: + if logger: logger.info("Transaction committed successfully.") def rollback(self) -> None: @@ -222,7 +222,7 @@ def rollback(self) -> None: """ # Roll back the current transaction self._conn.rollback() - if ENABLE_LOGGING: + if logger: logger.info("Transaction rolled back successfully.") def close(self) -> None: @@ -255,11 +255,11 @@ def close(self) -> None: except Exception as e: # Collect errors but continue closing other cursors close_errors.append(f"Error closing cursor: {e}") - if ENABLE_LOGGING: + if logger: logger.warning(f"Error closing cursor: {e}") # If there were errors closing cursors, log them but continue - if close_errors and ENABLE_LOGGING: + if close_errors and logger: logger.warning(f"Encountered {len(close_errors)} errors while closing cursors") # Clear the cursor set explicitly to release any internal references @@ -271,7 +271,7 @@ def close(self) -> None: self._conn.close() self._conn = None except Exception as e: - if ENABLE_LOGGING: + if logger: logger.error(f"Error closing database connection: {e}") # Re-raise the connection close error as it's more critical raise @@ -279,5 +279,5 @@ def close(self) -> None: # Always mark as closed, even if there were errors self._closed = True - if ENABLE_LOGGING: + if logger: logger.info("Connection closed successfully.") diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 6e2efc9e7..d30efcdbe 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -15,7 +15,7 @@ from typing import List, Union from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger from mssql_python import ddbc_bindings from .row import Row @@ -431,7 +431,7 @@ def _reset_cursor(self) -> None: if self.hstmt: self.hstmt.free() self.hstmt = None - if ENABLE_LOGGING: + if logger: logger.debug("SQLFreeHandle succeeded") # Reinitialize the statement handle self._initialize_cursor() @@ -449,7 +449,7 @@ def close(self) -> None: if self.hstmt: self.hstmt.free() self.hstmt = None - if ENABLE_LOGGING: + if logger: logger.debug("SQLFreeHandle succeeded") self.closed = True @@ -584,7 +584,7 @@ def execute( # Executing a new statement. Reset is_stmt_prepared to false self.is_stmt_prepared = [False] - if ENABLE_LOGGING: + if logger: logger.debug("Executing query: %s", operation) for i, param in enumerate(parameters): logger.debug( @@ -637,7 +637,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: total_rowcount = 0 for parameters in seq_of_parameters: parameters = list(parameters) - if ENABLE_LOGGING: + if logger: logger.info("Executing query with parameters: %s", parameters) prepare_stmt = first_execution first_execution = False diff --git a/mssql_python/exceptions.py b/mssql_python/exceptions.py index c2307a5f5..308a85690 100644 --- a/mssql_python/exceptions.py +++ b/mssql_python/exceptions.py @@ -4,7 +4,7 @@ This module contains custom exception classes for the mssql_python package. These classes are used to raise exceptions when an error occurs while executing a query. """ -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger logger = get_logger() @@ -621,7 +621,7 @@ def truncate_error_message(error_message: str) -> str: string_third = string_second[string_second.index("]") + 1 :] return string_first + string_third except Exception as e: - if ENABLE_LOGGING: + if logger: logger.error("Error while truncating error message: %s",e) return error_message @@ -641,7 +641,7 @@ def raise_exception(sqlstate: str, ddbc_error: str) -> None: """ exception_class = sqlstate_to_exception(sqlstate, ddbc_error) if exception_class: - if ENABLE_LOGGING: + if logger: logger.error(exception_class) raise exception_class raise DatabaseError( diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index cffb06467..cecfc39dc 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -6,7 +6,7 @@ from mssql_python import ddbc_bindings from mssql_python.exceptions import raise_exception -from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.logging_config import get_logger import platform from pathlib import Path from mssql_python.ddbc_bindings import normalize_architecture @@ -73,7 +73,7 @@ def check_error(handle_type, handle, ret): """ if ret < 0: error_info = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret) - if ENABLE_LOGGING: + if logger: logger.error("Error: %s", error_info.ddbcErrorMsg) raise_exception(error_info.sqlState, error_info.ddbcErrorMsg) diff --git a/mssql_python/logging_config.py b/mssql_python/logging_config.py index d0952724f..d8d45cbca 100644 --- a/mssql_python/logging_config.py +++ b/mssql_python/logging_config.py @@ -9,7 +9,7 @@ import os import sys -ENABLE_LOGGING = False +logger = False def setup_logging(mode="file", log_level=logging.DEBUG): @@ -23,8 +23,8 @@ def setup_logging(mode="file", log_level=logging.DEBUG): mode (str): The logging mode ('file' or 'stdout'). log_level (int): The logging level (default: logging.DEBUG). """ - global ENABLE_LOGGING - ENABLE_LOGGING = True + global logger + logger = True # Create a logger for mssql_python module logger = logging.getLogger(__name__) @@ -60,6 +60,6 @@ def get_logger(): Returns: logging.Logger: The logger instance. """ - if not ENABLE_LOGGING: + if not logger: return None return logging.getLogger(__name__) diff --git a/tests/test_006_logging.py b/tests/test_006_logging.py index e78c29eb5..feb637b02 100644 --- a/tests/test_006_logging.py +++ b/tests/test_006_logging.py @@ -1,7 +1,7 @@ import logging import os import pytest -from mssql_python.logging_config import setup_logging, get_logger, ENABLE_LOGGING +from mssql_python.logging_config import setup_logging, get_logger, logger def get_log_file_path(): repo_root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -19,7 +19,7 @@ def cleanup(): log_file_path = get_log_file_path() if os.path.exists(log_file_path): os.remove(log_file_path) - ENABLE_LOGGING = False + logger = False # Perform cleanup before the test cleanup() yield @@ -31,7 +31,7 @@ def test_no_logging(cleanup_logger): try: logger = get_logger() assert logger is None - assert ENABLE_LOGGING == False + assert logger == False except Exception as e: pytest.fail(f"Logging not off by default. Error: {e}") From 6827dbce7d7f8be14fc1fa9e1f0f5651047f9007 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 16 Jul 2025 12:36:49 +0530 Subject: [PATCH 07/12] fixed conflicts again --- tests/test_006_logging.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_006_logging.py b/tests/test_006_logging.py index feb637b02..a32185d6a 100644 --- a/tests/test_006_logging.py +++ b/tests/test_006_logging.py @@ -31,7 +31,6 @@ def test_no_logging(cleanup_logger): try: logger = get_logger() assert logger is None - assert logger == False except Exception as e: pytest.fail(f"Logging not off by default. Error: {e}") From f093eaf3d7b605e715b82d018116249e1be70926 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 16 Jul 2025 12:52:26 +0530 Subject: [PATCH 08/12] restored others to this branch - reverting back --- mssql_python/connection.py | 18 +++++++++--------- mssql_python/constants.py | 6 ------ mssql_python/cursor.py | 10 +++++----- mssql_python/exceptions.py | 6 +++--- mssql_python/helpers.py | 4 ++-- mssql_python/logging_config.py | 8 ++++---- tests/test_006_logging.py | 5 +++-- 7 files changed, 26 insertions(+), 31 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 71e8d209d..c66247414 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -13,7 +13,7 @@ import weakref import re from mssql_python.cursor import Cursor -from mssql_python.logging_config import get_logger +from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import add_driver_to_connection_str, check_error from mssql_python import ddbc_bindings @@ -126,7 +126,7 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st continue conn_str += f"{key}={value};" - if logger: + if ENABLE_LOGGING logger.info("Final connection string: %s", conn_str) return conn_str @@ -150,7 +150,7 @@ def autocommit(self, value: bool) -> None: None """ self.setautocommit(value) - if logger: + if ENABLE_LOGGING logger.info("Autocommit mode set to %s.", value) def setautocommit(self, value: bool = True) -> None: @@ -206,7 +206,7 @@ def commit(self) -> None: """ # Commit the current transaction self._conn.commit() - if logger: + if ENABLE_LOGGING logger.info("Transaction committed successfully.") def rollback(self) -> None: @@ -222,7 +222,7 @@ def rollback(self) -> None: """ # Roll back the current transaction self._conn.rollback() - if logger: + if ENABLE_LOGGING logger.info("Transaction rolled back successfully.") def close(self) -> None: @@ -255,11 +255,11 @@ def close(self) -> None: except Exception as e: # Collect errors but continue closing other cursors close_errors.append(f"Error closing cursor: {e}") - if logger: + if ENABLE_LOGGING logger.warning(f"Error closing cursor: {e}") # If there were errors closing cursors, log them but continue - if close_errors and logger: + if close_errors and ENABLE_LOGGING logger.warning(f"Encountered {len(close_errors)} errors while closing cursors") # Clear the cursor set explicitly to release any internal references @@ -271,7 +271,7 @@ def close(self) -> None: self._conn.close() self._conn = None except Exception as e: - if logger: + if ENABLE_LOGGING logger.error(f"Error closing database connection: {e}") # Re-raise the connection close error as it's more critical raise @@ -279,5 +279,5 @@ def close(self) -> None: # Always mark as closed, even if there were errors self._closed = True - if logger: + if ENABLE_LOGGING logger.info("Connection closed successfully.") diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 81e60d37e..aade503c7 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -116,9 +116,3 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 - -class AuthType(Enum): - """Constants for authentication types""" - INTERACTIVE = "activedirectoryinteractive" - DEVICE_CODE = "activedirectorydevicecode" - DEFAULT = "activedirectorydefault" \ No newline at end of file diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index d30efcdbe..6e2efc9e7 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -15,7 +15,7 @@ from typing import List, Union from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error -from mssql_python.logging_config import get_logger +from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python import ddbc_bindings from .row import Row @@ -431,7 +431,7 @@ def _reset_cursor(self) -> None: if self.hstmt: self.hstmt.free() self.hstmt = None - if logger: + if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") # Reinitialize the statement handle self._initialize_cursor() @@ -449,7 +449,7 @@ def close(self) -> None: if self.hstmt: self.hstmt.free() self.hstmt = None - if logger: + if ENABLE_LOGGING: logger.debug("SQLFreeHandle succeeded") self.closed = True @@ -584,7 +584,7 @@ def execute( # Executing a new statement. Reset is_stmt_prepared to false self.is_stmt_prepared = [False] - if logger: + if ENABLE_LOGGING: logger.debug("Executing query: %s", operation) for i, param in enumerate(parameters): logger.debug( @@ -637,7 +637,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: total_rowcount = 0 for parameters in seq_of_parameters: parameters = list(parameters) - if logger: + if ENABLE_LOGGING: logger.info("Executing query with parameters: %s", parameters) prepare_stmt = first_execution first_execution = False diff --git a/mssql_python/exceptions.py b/mssql_python/exceptions.py index 308a85690..c2307a5f5 100644 --- a/mssql_python/exceptions.py +++ b/mssql_python/exceptions.py @@ -4,7 +4,7 @@ This module contains custom exception classes for the mssql_python package. These classes are used to raise exceptions when an error occurs while executing a query. """ -from mssql_python.logging_config import get_logger +from mssql_python.logging_config import get_logger, ENABLE_LOGGING logger = get_logger() @@ -621,7 +621,7 @@ def truncate_error_message(error_message: str) -> str: string_third = string_second[string_second.index("]") + 1 :] return string_first + string_third except Exception as e: - if logger: + if ENABLE_LOGGING: logger.error("Error while truncating error message: %s",e) return error_message @@ -641,7 +641,7 @@ def raise_exception(sqlstate: str, ddbc_error: str) -> None: """ exception_class = sqlstate_to_exception(sqlstate, ddbc_error) if exception_class: - if logger: + if ENABLE_LOGGING: logger.error(exception_class) raise exception_class raise DatabaseError( diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index cecfc39dc..cffb06467 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -6,7 +6,7 @@ from mssql_python import ddbc_bindings from mssql_python.exceptions import raise_exception -from mssql_python.logging_config import get_logger +from mssql_python.logging_config import get_logger, ENABLE_LOGGING import platform from pathlib import Path from mssql_python.ddbc_bindings import normalize_architecture @@ -73,7 +73,7 @@ def check_error(handle_type, handle, ret): """ if ret < 0: error_info = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret) - if logger: + if ENABLE_LOGGING: logger.error("Error: %s", error_info.ddbcErrorMsg) raise_exception(error_info.sqlState, error_info.ddbcErrorMsg) diff --git a/mssql_python/logging_config.py b/mssql_python/logging_config.py index d8d45cbca..d0952724f 100644 --- a/mssql_python/logging_config.py +++ b/mssql_python/logging_config.py @@ -9,7 +9,7 @@ import os import sys -logger = False +ENABLE_LOGGING = False def setup_logging(mode="file", log_level=logging.DEBUG): @@ -23,8 +23,8 @@ def setup_logging(mode="file", log_level=logging.DEBUG): mode (str): The logging mode ('file' or 'stdout'). log_level (int): The logging level (default: logging.DEBUG). """ - global logger - logger = True + global ENABLE_LOGGING + ENABLE_LOGGING = True # Create a logger for mssql_python module logger = logging.getLogger(__name__) @@ -60,6 +60,6 @@ def get_logger(): Returns: logging.Logger: The logger instance. """ - if not logger: + if not ENABLE_LOGGING: return None return logging.getLogger(__name__) diff --git a/tests/test_006_logging.py b/tests/test_006_logging.py index a32185d6a..e78c29eb5 100644 --- a/tests/test_006_logging.py +++ b/tests/test_006_logging.py @@ -1,7 +1,7 @@ import logging import os import pytest -from mssql_python.logging_config import setup_logging, get_logger, logger +from mssql_python.logging_config import setup_logging, get_logger, ENABLE_LOGGING def get_log_file_path(): repo_root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -19,7 +19,7 @@ def cleanup(): log_file_path = get_log_file_path() if os.path.exists(log_file_path): os.remove(log_file_path) - logger = False + ENABLE_LOGGING = False # Perform cleanup before the test cleanup() yield @@ -31,6 +31,7 @@ def test_no_logging(cleanup_logger): try: logger = get_logger() assert logger is None + assert ENABLE_LOGGING == False except Exception as e: pytest.fail(f"Logging not off by default. Error: {e}") From d260e5261168f6489366a57a4feb0d33fa7c5bb1 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 16 Jul 2025 12:53:23 +0530 Subject: [PATCH 09/12] restoring constants and setup --- mssql_python/constants.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mssql_python/constants.py b/mssql_python/constants.py index aade503c7..81e60d37e 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -116,3 +116,9 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + +class AuthType(Enum): + """Constants for authentication types""" + INTERACTIVE = "activedirectoryinteractive" + DEVICE_CODE = "activedirectorydevicecode" + DEFAULT = "activedirectorydefault" \ No newline at end of file From fdadbca87ae35854e5a2285c64d4a425d84f8dcc Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 16 Jul 2025 14:02:33 +0530 Subject: [PATCH 10/12] Adding setup dependencies --- setup.py | 4 ++++ tests/{test_007_auth.py => test_008_auth.py} | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) rename tests/{test_007_auth.py => test_008_auth.py} (98%) diff --git a/setup.py b/setup.py index 273ea0867..d5e855ede 100644 --- a/setup.py +++ b/setup.py @@ -100,6 +100,10 @@ def finalize_options(self): include_package_data=True, # Requires >= Python 3.10 python_requires='>=3.10', + # Add dependencies + install_requires=[ + 'azure-identity>=1.12.0', # Azure authentication library + ], classifiers=[ 'Operating System :: Microsoft :: Windows', 'Operating System :: MacOS', diff --git a/tests/test_007_auth.py b/tests/test_008_auth.py similarity index 98% rename from tests/test_007_auth.py rename to tests/test_008_auth.py index 52cea8178..1ccb2d7d8 100644 --- a/tests/test_007_auth.py +++ b/tests/test_008_auth.py @@ -15,9 +15,9 @@ process_connection_string ) from mssql_python.constants import AuthType +import secrets -# Test data -SAMPLE_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6I" +SAMPLE_TOKEN = secrets.token_hex(44) @pytest.fixture(autouse=True) def setup_azure_identity(): From 992aa6d5a80209e88951ce952d371fed900338b4 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 16 Jul 2025 14:14:24 +0530 Subject: [PATCH 11/12] Resolving conflicts --- mssql_python/connection.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 230978795..8456ef92d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -126,7 +126,7 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st continue conn_str += f"{key}={value};" - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.info("Final connection string: %s", conn_str) return conn_str @@ -150,7 +150,7 @@ def autocommit(self, value: bool) -> None: None """ self.setautocommit(value) - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.info("Autocommit mode set to %s.", value) def setautocommit(self, value: bool = True) -> None: @@ -205,7 +205,7 @@ def commit(self) -> None: """ # Commit the current transaction self._conn.commit() - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.info("Transaction committed successfully.") def rollback(self) -> None: @@ -221,7 +221,7 @@ def rollback(self) -> None: """ # Roll back the current transaction self._conn.rollback() - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.info("Transaction rolled back successfully.") def close(self) -> None: @@ -254,11 +254,11 @@ def close(self) -> None: except Exception as e: # Collect errors but continue closing other cursors close_errors.append(f"Error closing cursor: {e}") - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.warning(f"Error closing cursor: {e}") # If there were errors closing cursors, log them but continue - if close_errors and ENABLE_LOGGING + if close_errors and ENABLE_LOGGING: logger.warning(f"Encountered {len(close_errors)} errors while closing cursors") # Clear the cursor set explicitly to release any internal references @@ -270,7 +270,7 @@ def close(self) -> None: self._conn.close() self._conn = None except Exception as e: - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.error(f"Error closing database connection: {e}") # Re-raise the connection close error as it's more critical raise @@ -278,7 +278,7 @@ def close(self) -> None: # Always mark as closed, even if there were errors self._closed = True - if ENABLE_LOGGING + if ENABLE_LOGGING: logger.info("Connection closed successfully.") def __del__(self): From 2c61a41dc6987cb0c5e0ac3a5d1c6da85ab0825a Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 16 Jul 2025 20:14:23 +0530 Subject: [PATCH 12/12] Resolving comments --- mssql_python/auth.py | 77 ++++++++++++++++++++---------------------- tests/test_008_auth.py | 66 +++++++++++++++++++++++++++++++----- 2 files changed, 93 insertions(+), 50 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index d67bc5efc..30efd6d99 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -7,7 +7,7 @@ import platform import struct from typing import Tuple, Dict, Optional, Union -from mssql_python.logging_config import get_logger +from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python.constants import AuthType logger = get_logger() @@ -22,42 +22,38 @@ def get_token_struct(token: str) -> bytes: return struct.pack(f" bytes: - """Get token using DefaultAzureCredential""" - from azure.identity import DefaultAzureCredential - - try: - # DefaultAzureCredential will automatically use the best available method - # based on the environment (e.g., managed identity, environment variables) - credential = DefaultAzureCredential() - token = credential.get_token("https://database.windows.net/.default").token - return AADAuth.get_token_struct(token) - except Exception as e: - raise RuntimeError(f"Failed to create DefaultAzureCredential: {e}") - - @staticmethod - def get_device_code_token() -> bytes: - """Get token using DeviceCodeCredential""" - from azure.identity import DeviceCodeCredential - - try: - credential = DeviceCodeCredential() - token = credential.get_token("https://database.windows.net/.default").token - return AADAuth.get_token_struct(token) - except Exception as e: - raise RuntimeError(f"Failed to create DeviceCodeCredential: {e}") - - @staticmethod - def get_interactive_token() -> bytes: - """Get token using InteractiveBrowserCredential""" - from azure.identity import InteractiveBrowserCredential + def get_token(auth_type: str) -> bytes: + """Get token using the specified authentication type""" + from azure.identity import ( + DefaultAzureCredential, + DeviceCodeCredential, + InteractiveBrowserCredential + ) + from azure.core.exceptions import ClientAuthenticationError + + # Mapping of auth types to credential classes + credential_map = { + "default": DefaultAzureCredential, + "devicecode": DeviceCodeCredential, + "interactive": InteractiveBrowserCredential, + } + + credential_class = credential_map[auth_type] try: - credential = InteractiveBrowserCredential() + credential = credential_class() token = credential.get_token("https://database.windows.net/.default").token return AADAuth.get_token_struct(token) + except ClientAuthenticationError as e: + # Re-raise with more specific context about Azure AD authentication failure + raise RuntimeError( + f"Azure AD authentication failed for {credential_class.__name__}: {e}. " + f"This could be due to invalid credentials, missing environment variables, " + f"user cancellation, network issues, or unsupported configuration." + ) from e except Exception as e: - raise RuntimeError(f"Failed to create InteractiveBrowserCredential: {e}") + # Catch any other unexpected exceptions + raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]: """ @@ -120,15 +116,14 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: if not auth_type: return None - if auth_type == "default": - return AADAuth.get_default_token() - elif auth_type == "devicecode": - return AADAuth.get_device_code_token() - # If interactive authentication is requested, use InteractiveBrowserCredential - # but only if not on Windows, since in Windows: AADInteractive is supported. - elif auth_type == "interactive" and platform.system().lower() != "windows": - return AADAuth.get_interactive_token() - return None + # Handle platform-specific logic for interactive auth + if auth_type == "interactive" and platform.system().lower() == "windows": + return None # Let Windows handle AADInteractive natively + + try: + return AADAuth.get_token(auth_type) + except (ValueError, RuntimeError): + return None def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]: """ diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 1ccb2d7d8..2c3c1a0ad 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -37,23 +37,34 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + # Mock ClientAuthenticationError + class MockClientAuthenticationError(Exception): + pass + class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + class MockCore: + class exceptions: + ClientAuthenticationError = MockClientAuthenticationError + # Create mock azure module if it doesn't exist if 'azure' not in sys.modules: sys.modules['azure'] = type('MockAzure', (), {})() - # Add identity module to azure + # Add identity and core modules to azure sys.modules['azure.identity'] = MockIdentity() + sys.modules['azure.core'] = MockCore() + sys.modules['azure.core.exceptions'] = MockCore.exceptions() yield # Cleanup - if 'azure.identity' in sys.modules: - del sys.modules['azure.identity'] + for module in ['azure.identity', 'azure.core', 'azure.core.exceptions']: + if module in sys.modules: + del sys.modules[module] class TestAuthType: def test_auth_type_constants(self): @@ -67,18 +78,55 @@ def test_get_token_struct(self): assert isinstance(token_struct, bytes) assert len(token_struct) > 4 - def test_get_default_token(self): - token_struct = AADAuth.get_default_token() + def test_get_token_default(self): + token_struct = AADAuth.get_token("default") assert isinstance(token_struct, bytes) - def test_get_device_code_token(self): - token_struct = AADAuth.get_device_code_token() + def test_get_token_device_code(self): + token_struct = AADAuth.get_token("devicecode") assert isinstance(token_struct, bytes) - def test_get_interactive_token(self): - token_struct = AADAuth.get_interactive_token() + def test_get_token_interactive(self): + token_struct = AADAuth.get_token("interactive") assert isinstance(token_struct, bytes) + def test_get_token_credential_mapping(self): + # Test that all supported auth types work + supported_types = ["default", "devicecode", "interactive"] + for auth_type in supported_types: + token_struct = AADAuth.get_token(auth_type) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + + def test_get_token_client_authentication_error(self): + """Test that ClientAuthenticationError is properly handled""" + from azure.core.exceptions import ClientAuthenticationError + + # Create a mock credential that raises ClientAuthenticationError + class MockFailingCredential: + def get_token(self, scope): + raise ClientAuthenticationError("Mock authentication failed") + + # Use monkeypatch to mock the credential creation + def mock_get_token_failing(auth_type): + from azure.core.exceptions import ClientAuthenticationError + if auth_type == "default": + try: + credential = MockFailingCredential() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except ClientAuthenticationError as e: + raise RuntimeError( + f"Azure AD authentication failed for MockFailingCredential: {e}. " + f"This could be due to invalid credentials, missing environment variables, " + f"user cancellation, network issues, or unsupported configuration." + ) from e + else: + return AADAuth.get_token(auth_type) + + with pytest.raises(RuntimeError, match="Azure AD authentication failed"): + mock_get_token_failing("default") + class TestProcessAuthParameters: def test_empty_parameters(self): modified_params, auth_type = process_auth_parameters([])