diff --git a/README.md b/README.md index 8359ccb64..32f9eaebb 100644 --- a/README.md +++ b/README.md @@ -48,17 +48,23 @@ By adhering to the DB API 2.0 specification, the mssql-python module ensures com ### Support for Microsoft Entra ID Authentication -The Microsoft mssql-python driver enables Python applications to connect to Microsoft SQL Server, Azure SQL Database, or Azure SQL Managed Instance using Microsoft Entra ID identities. It supports various authentication methods, including username and password, Microsoft Entra managed identity, and Integrated Windows Authentication in a federated, domain-joined environment. Additionally, the driver supports Microsoft Entra interactive authentication and Microsoft Entra managed identity authentication for both system-assigned and user-assigned managed identities. +The Microsoft mssql-python driver enables Python applications to connect to Microsoft SQL Server, Azure SQL Database, or Azure SQL Managed Instance using Microsoft Entra ID identities. It supports a variety of authentication methods, including username and password, Microsoft Entra managed identity (system-assigned and user-assigned), Integrated Windows Authentication in a federated, domain-joined environment, interactive authentication via browser, device code flow for environments without browser access, and the default authentication method based on environment and configuration. This flexibility allows developers to choose the most suitable authentication approach for their deployment scenario. EntraID authentication is now fully supported on MacOS and Linux but with certain limitations as mentioned in the table: | Authentication Method | Windows Support | macOS/Linux Support | Notes | |----------------------|----------------|---------------------|-------| | ActiveDirectoryPassword | ✅ Yes | ✅ Yes | Username/password-based authentication | -| ActiveDirectoryInteractive | ✅ Yes | ❌ No | Only works on Windows | +| ActiveDirectoryInteractive | ✅ Yes | ✅ Yes | Interactive login via browser; requires user interaction | | ActiveDirectoryMSI (Managed Identity) | ✅ Yes | ✅ Yes | For Azure VMs/containers with managed identity | | ActiveDirectoryServicePrincipal | ✅ Yes | ✅ Yes | Use client ID and secret or certificate | | ActiveDirectoryIntegrated | ✅ Yes | ❌ No | Only works on Windows (requires Kerberos/SSPI) | +| ActiveDirectoryDeviceCode | ✅ Yes | ✅ Yes | Device code flow for authentication; suitable for environments without browser access | +| ActiveDirectoryDefault | ✅ Yes | ✅ Yes | Uses default authentication method based on environment and configuration | + +**NOTE**: For using Access Token, the connection string *must not* contain `UID`, `PWD`, `Authentication`, or `Trusted_Connection` keywords. + +**NOTE**: For using ActiveDirectoryDeviceCode, make sure to specify a `Connect Timeout` that provides enough time to go through the device code flow authentication process. ### Enhanced Pythonic Features diff --git a/mssql_python/auth.py b/mssql_python/auth.py new file mode 100644 index 000000000..30efd6d99 --- /dev/null +++ b/mssql_python/auth.py @@ -0,0 +1,163 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +This module handles authentication for the mssql_python package. +""" + +import platform +import struct +from typing import Tuple, Dict, Optional, Union +from mssql_python.logging_config import get_logger, ENABLE_LOGGING +from mssql_python.constants import AuthType + +logger = get_logger() + +class AADAuth: + """Handles Azure Active Directory authentication""" + + @staticmethod + def get_token_struct(token: str) -> bytes: + """Convert token to SQL Server compatible format""" + token_bytes = token.encode("UTF-16-LE") + return struct.pack(f" bytes: + """Get token using the specified authentication type""" + from azure.identity import ( + DefaultAzureCredential, + DeviceCodeCredential, + InteractiveBrowserCredential + ) + from azure.core.exceptions import ClientAuthenticationError + + # Mapping of auth types to credential classes + credential_map = { + "default": DefaultAzureCredential, + "devicecode": DeviceCodeCredential, + "interactive": InteractiveBrowserCredential, + } + + credential_class = credential_map[auth_type] + + try: + credential = credential_class() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except ClientAuthenticationError as e: + # Re-raise with more specific context about Azure AD authentication failure + raise RuntimeError( + f"Azure AD authentication failed for {credential_class.__name__}: {e}. " + f"This could be due to invalid credentials, missing environment variables, " + f"user cancellation, network issues, or unsupported configuration." + ) from e + except Exception as e: + # Catch any other unexpected exceptions + raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e + +def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]: + """ + Process connection parameters and extract authentication type. + + Args: + parameters: List of connection string parameters + + Returns: + Tuple[list, Optional[str]]: Modified parameters and authentication type + + Raises: + ValueError: If an invalid authentication type is provided + """ + modified_parameters = [] + auth_type = None + + for param in parameters: + param = param.strip() + if not param: + continue + + if "=" not in param: + modified_parameters.append(param) + continue + + key, value = param.split("=", 1) + key_lower = key.lower() + value_lower = value.lower() + + if key_lower == "authentication": + # Check for supported authentication types and set auth_type accordingly + if value_lower == AuthType.INTERACTIVE.value: + # Interactive authentication (browser-based); only append parameter for non-Windows + if platform.system().lower() == "windows": + continue # Skip adding this parameter for Windows + auth_type = "interactive" + elif value_lower == AuthType.DEVICE_CODE.value: + # Device code authentication (for devices without browser) + auth_type = "devicecode" + elif value_lower == AuthType.DEFAULT.value: + # Default authentication (uses DefaultAzureCredential) + auth_type = "default" + modified_parameters.append(param) + + return modified_parameters, auth_type + +def remove_sensitive_params(parameters: list) -> list: + """Remove sensitive parameters from connection string""" + exclude_keys = [ + "uid=", "pwd=", "encrypt=", "trustservercertificate=", "authentication=" + ] + return [ + param for param in parameters + if not any(param.lower().startswith(exclude) for exclude in exclude_keys) + ] + +def get_auth_token(auth_type: str) -> Optional[bytes]: + """Get authentication token based on auth type""" + if not auth_type: + return None + + # Handle platform-specific logic for interactive auth + if auth_type == "interactive" and platform.system().lower() == "windows": + return None # Let Windows handle AADInteractive natively + + try: + return AADAuth.get_token(auth_type) + except (ValueError, RuntimeError): + return None + +def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]: + """ + Process connection string and handle authentication. + + Args: + connection_string: The connection string to process + + Returns: + Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed + + Raises: + ValueError: If the connection string is invalid or empty + """ + # Check type first + if not isinstance(connection_string, str): + raise ValueError("Connection string must be a string") + + # Then check if empty + if not connection_string: + raise ValueError("Connection string cannot be empty") + + parameters = connection_string.split(";") + + # Validate that there's at least one valid parameter + if not any('=' in param for param in parameters): + raise ValueError("Invalid connection string format") + + modified_parameters, auth_type = process_auth_parameters(parameters) + + if auth_type: + modified_parameters = remove_sensitive_params(modified_parameters) + token_struct = get_auth_token(auth_type) + if token_struct: + return ";".join(modified_parameters) + ";", {1256: token_struct} + + return ";".join(modified_parameters) + ";", None \ No newline at end of file diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f706ee3ab..8456ef92d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -11,6 +11,7 @@ - Cursors are also cleaned up automatically when no longer referenced, to prevent memory leaks. """ import weakref +import re from mssql_python.cursor import Cursor from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python.constants import ConstantsDDBC as ddbc_sql_const @@ -18,6 +19,7 @@ from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager from mssql_python.exceptions import DatabaseError, InterfaceError +from mssql_python.auth import process_connection_string logger = get_logger() @@ -64,6 +66,17 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef connection_str, **kwargs ) self._attrs_before = attrs_before or {} + + # Check if the connection string contains authentication parameters + # This is important for processing the connection string correctly. + # If authentication is specified, it will be processed to handle + # different authentication types like interactive, device code, etc. + if re.search(r"authentication", self.connection_str, re.IGNORECASE): + connection_result = process_connection_string(self.connection_str) + self.connection_str = connection_result[0] + if connection_result[1]: + self._attrs_before.update(connection_result[1]) + self._closed = False # Using WeakSet which automatically removes cursors when they are no longer in use diff --git a/mssql_python/constants.py b/mssql_python/constants.py index aade503c7..81e60d37e 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -116,3 +116,9 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + +class AuthType(Enum): + """Constants for authentication types""" + INTERACTIVE = "activedirectoryinteractive" + DEVICE_CODE = "activedirectorydevicecode" + DEFAULT = "activedirectorydefault" \ No newline at end of file diff --git a/setup.py b/setup.py index 273ea0867..d5e855ede 100644 --- a/setup.py +++ b/setup.py @@ -100,6 +100,10 @@ def finalize_options(self): include_package_data=True, # Requires >= Python 3.10 python_requires='>=3.10', + # Add dependencies + install_requires=[ + 'azure-identity>=1.12.0', # Azure authentication library + ], classifiers=[ 'Operating System :: Microsoft :: Windows', 'Operating System :: MacOS', diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py new file mode 100644 index 000000000..2c3c1a0ad --- /dev/null +++ b/tests/test_008_auth.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +Tests for the auth module. +""" + +import pytest +import platform +import sys +from mssql_python.auth import ( + AADAuth, + process_auth_parameters, + remove_sensitive_params, + get_auth_token, + process_connection_string +) +from mssql_python.constants import AuthType +import secrets + +SAMPLE_TOKEN = secrets.token_hex(44) + +@pytest.fixture(autouse=True) +def setup_azure_identity(): + """Setup mock azure.identity module""" + class MockToken: + token = SAMPLE_TOKEN + + class MockDefaultAzureCredential: + def get_token(self, scope): + return MockToken() + + class MockDeviceCodeCredential: + def get_token(self, scope): + return MockToken() + + class MockInteractiveBrowserCredential: + def get_token(self, scope): + return MockToken() + + # Mock ClientAuthenticationError + class MockClientAuthenticationError(Exception): + pass + + class MockIdentity: + DefaultAzureCredential = MockDefaultAzureCredential + DeviceCodeCredential = MockDeviceCodeCredential + InteractiveBrowserCredential = MockInteractiveBrowserCredential + + class MockCore: + class exceptions: + ClientAuthenticationError = MockClientAuthenticationError + + # Create mock azure module if it doesn't exist + if 'azure' not in sys.modules: + sys.modules['azure'] = type('MockAzure', (), {})() + + # Add identity and core modules to azure + sys.modules['azure.identity'] = MockIdentity() + sys.modules['azure.core'] = MockCore() + sys.modules['azure.core.exceptions'] = MockCore.exceptions() + + yield + + # Cleanup + for module in ['azure.identity', 'azure.core', 'azure.core.exceptions']: + if module in sys.modules: + del sys.modules[module] + +class TestAuthType: + def test_auth_type_constants(self): + assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" + assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" + assert AuthType.DEFAULT.value == "activedirectorydefault" + +class TestAADAuth: + def test_get_token_struct(self): + token_struct = AADAuth.get_token_struct(SAMPLE_TOKEN) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + + def test_get_token_default(self): + token_struct = AADAuth.get_token("default") + assert isinstance(token_struct, bytes) + + def test_get_token_device_code(self): + token_struct = AADAuth.get_token("devicecode") + assert isinstance(token_struct, bytes) + + def test_get_token_interactive(self): + token_struct = AADAuth.get_token("interactive") + assert isinstance(token_struct, bytes) + + def test_get_token_credential_mapping(self): + # Test that all supported auth types work + supported_types = ["default", "devicecode", "interactive"] + for auth_type in supported_types: + token_struct = AADAuth.get_token(auth_type) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + + def test_get_token_client_authentication_error(self): + """Test that ClientAuthenticationError is properly handled""" + from azure.core.exceptions import ClientAuthenticationError + + # Create a mock credential that raises ClientAuthenticationError + class MockFailingCredential: + def get_token(self, scope): + raise ClientAuthenticationError("Mock authentication failed") + + # Use monkeypatch to mock the credential creation + def mock_get_token_failing(auth_type): + from azure.core.exceptions import ClientAuthenticationError + if auth_type == "default": + try: + credential = MockFailingCredential() + token = credential.get_token("https://database.windows.net/.default").token + return AADAuth.get_token_struct(token) + except ClientAuthenticationError as e: + raise RuntimeError( + f"Azure AD authentication failed for MockFailingCredential: {e}. " + f"This could be due to invalid credentials, missing environment variables, " + f"user cancellation, network issues, or unsupported configuration." + ) from e + else: + return AADAuth.get_token(auth_type) + + with pytest.raises(RuntimeError, match="Azure AD authentication failed"): + mock_get_token_failing("default") + +class TestProcessAuthParameters: + def test_empty_parameters(self): + modified_params, auth_type = process_auth_parameters([]) + assert modified_params == [] + assert auth_type is None + + def test_interactive_auth_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Windows") + params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryInteractive" not in modified_params + assert auth_type == None + + def test_interactive_auth_non_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Darwin") + params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert "Authentication=ActiveDirectoryInteractive" in modified_params + assert auth_type == "interactive" + + def test_device_code_auth(self): + params = ["Authentication=ActiveDirectoryDeviceCode", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert auth_type == "devicecode" + + def test_default_auth(self): + params = ["Authentication=ActiveDirectoryDefault", "Server=test"] + modified_params, auth_type = process_auth_parameters(params) + assert auth_type == "default" + +class TestRemoveSensitiveParams: + def test_remove_sensitive_parameters(self): + params = [ + "Server=test", + "UID=user", + "PWD=password", + "Encrypt=yes", + "TrustServerCertificate=yes", + "Authentication=ActiveDirectoryDefault", + "Database=testdb" + ] + filtered_params = remove_sensitive_params(params) + assert "Server=test" in filtered_params + assert "Database=testdb" in filtered_params + assert "UID=user" not in filtered_params + assert "PWD=password" not in filtered_params + assert "Encrypt=yes" not in filtered_params + assert "TrustServerCertificate=yes" not in filtered_params + assert "Authentication=ActiveDirectoryDefault" not in filtered_params + +class TestProcessConnectionString: + def test_process_connection_string_with_default_auth(self): + conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert 1256 in attrs + assert isinstance(attrs[1256], bytes) + + def test_process_connection_string_no_auth(self): + conn_str = "Server=test;Database=testdb;UID=user;PWD=password" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert "UID=user" in result_str + assert "PWD=password" in result_str + assert attrs is None + + def test_process_connection_string_interactive_non_windows(self, monkeypatch): + monkeypatch.setattr(platform, "system", lambda: "Darwin") + conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" + result_str, attrs = process_connection_string(conn_str) + + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert 1256 in attrs + assert isinstance(attrs[1256], bytes) + +def test_error_handling(): + # Empty string should raise ValueError + with pytest.raises(ValueError, match="Connection string cannot be empty"): + process_connection_string("") + + # Invalid connection string should raise ValueError + with pytest.raises(ValueError, match="Invalid connection string format"): + process_connection_string("InvalidConnectionString") + + # Test non-string input + with pytest.raises(ValueError, match="Connection string must be a string"): + process_connection_string(None) \ No newline at end of file