-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathauth.py
More file actions
168 lines (136 loc) · 6.13 KB
/
auth.py
File metadata and controls
168 lines (136 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
This module handles authentication for the mssql_python package.
"""
import platform
import struct
from typing import Tuple, Dict, Optional, Union
from mssql_python.logging_config import get_logger
from mssql_python.constants import AuthType
# Module-level logger obtained from the package's logging configuration helper.
# NOTE(review): `logger` is not referenced by any function in this module as
# shown here — confirm it is still needed (or add logging where useful).
logger = get_logger()
class AADAuth:
    """Handles Azure Active Directory authentication.

    Acquires access tokens through ``azure.identity`` credential classes and
    packs them into the length-prefixed byte layout produced by
    :meth:`get_token_struct`.
    """

    # OAuth scope requested for every token (Azure SQL Database resource).
    _SQL_SCOPE = "https://database.windows.net/.default"

    @staticmethod
    def get_token_struct(token: str) -> bytes:
        """Convert token to SQL Server compatible format.

        The token is encoded as UTF-16-LE and prefixed with the byte length of
        that encoding as a little-endian unsigned 32-bit integer.

        Args:
            token: The raw access token string.

        Returns:
            bytes: ``<uint32 little-endian length><UTF-16-LE token bytes>``.
        """
        token_bytes = token.encode("UTF-16-LE")
        return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)

    @staticmethod
    def _acquire_token(credential_cls, credential_name: str) -> bytes:
        """Instantiate *credential_cls*, request a SQL token and pack it.

        Args:
            credential_cls: A zero-argument credential class exposing
                ``get_token(scope)`` that returns an object with a ``.token``
                attribute.
            credential_name: Name used in the error message.

        Returns:
            bytes: The packed token (see :meth:`get_token_struct`).

        Raises:
            RuntimeError: If constructing the credential, fetching the token,
                or packing it fails. The original exception is chained as
                ``__cause__``.
        """
        try:
            credential = credential_cls()
            token = credential.get_token(AADAuth._SQL_SCOPE).token
            # Packing stays inside the try so that a malformed token is
            # reported as RuntimeError, exactly like a credential failure.
            return AADAuth.get_token_struct(token)
        except Exception as e:
            raise RuntimeError(f"Failed to create {credential_name}: {e}") from e

    @staticmethod
    def get_default_token() -> bytes:
        """Get token using DefaultAzureCredential.

        DefaultAzureCredential picks the best available method for the
        environment (e.g. managed identity, environment variables).

        Raises:
            ImportError: If ``azure.identity`` is not installed.
            RuntimeError: If the credential or token acquisition fails.
        """
        from azure.identity import DefaultAzureCredential
        return AADAuth._acquire_token(DefaultAzureCredential, "DefaultAzureCredential")

    @staticmethod
    def get_device_code_token() -> bytes:
        """Get token using DeviceCodeCredential.

        Raises:
            ImportError: If ``azure.identity`` is not installed.
            RuntimeError: If the credential or token acquisition fails.
        """
        from azure.identity import DeviceCodeCredential
        return AADAuth._acquire_token(DeviceCodeCredential, "DeviceCodeCredential")

    @staticmethod
    def get_interactive_token() -> bytes:
        """Get token using InteractiveBrowserCredential.

        Raises:
            ImportError: If ``azure.identity`` is not installed.
            RuntimeError: If the credential or token acquisition fails.
        """
        from azure.identity import InteractiveBrowserCredential
        return AADAuth._acquire_token(
            InteractiveBrowserCredential, "InteractiveBrowserCredential"
        )
def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
    """
    Process connection parameters and extract authentication type.

    Every entry is stripped of surrounding whitespace and blank entries are
    dropped; all other entries (including the ``Authentication=`` entry) are
    kept in order. For an ``Authentication`` key (case-insensitive), the value
    is compared case-insensitively against the supported ``AuthType`` values:

    - interactive: ``"interactive"`` on non-Windows platforms. On Windows the
      parameter is kept for the driver to handle natively and no Python-side
      auth type is returned.
    - device code: ``"devicecode"``
    - default: ``"default"``

    Unrecognized authentication values are passed through unchanged and do not
    affect the returned auth type. If several recognized entries are present,
    the last one wins.

    Args:
        parameters: List of connection string parameters

    Returns:
        Tuple[list, Optional[str]]: Modified parameters and authentication type
        (``None`` when no Python-side token acquisition is required)
    """
    modified_parameters = []
    auth_type = None

    for param in parameters:
        param = param.strip()
        if not param:
            continue
        if "=" not in param:
            modified_parameters.append(param)
            continue

        key, value = param.split("=", 1)
        if key.lower() == "authentication":
            value_lower = value.lower()
            if value_lower == AuthType.INTERACTIVE.value:
                # Windows supports AADInteractive natively, so the parameter
                # must reach the driver unchanged there. Previously it was
                # dropped, leaving the connection with no interactive auth.
                if platform.system().lower() == "windows":
                    auth_type = None
                else:
                    auth_type = "interactive"
            elif value_lower == AuthType.DEVICE_CODE.value:
                # Device code authentication (for devices without a browser)
                auth_type = "devicecode"
            elif value_lower == AuthType.DEFAULT.value:
                # Default authentication (uses DefaultAzureCredential)
                auth_type = "default"

        modified_parameters.append(param)

    return modified_parameters, auth_type
def remove_sensitive_params(parameters: list) -> list:
    """Remove sensitive parameters from connection string.

    An entry is dropped when, compared case-insensitively, it begins with one
    of ``uid=``, ``pwd=``, ``encrypt=``, ``trustservercertificate=`` or
    ``authentication=``. All other entries are returned in their original
    order. Prefixes are matched literally, so an entry with whitespace before
    the ``=`` (e.g. ``"UID = x"``) is kept.

    NOTE(review): Encrypt/TrustServerCertificate are TLS settings rather than
    credentials; confirm that dropping them is intended.
    """
    blocked_prefixes = (
        "uid=",
        "pwd=",
        "encrypt=",
        "trustservercertificate=",
        "authentication=",
    )
    return [
        entry
        for entry in parameters
        if not entry.lower().startswith(blocked_prefixes)
    ]
def get_auth_token(auth_type: str) -> Optional[bytes]:
    """Get authentication token based on auth type.

    Args:
        auth_type: One of ``"default"``, ``"devicecode"`` or ``"interactive"``.
            Any other value (including ``None`` or ``""``) yields ``None``.

    Returns:
        Optional[bytes]: The packed token from the matching ``AADAuth``
        method, or ``None`` when no Python-side token is needed.
    """
    if auth_type == "default":
        return AADAuth.get_default_token()
    if auth_type == "devicecode":
        return AADAuth.get_device_code_token()
    if auth_type == "interactive":
        # Windows supports AADInteractive natively, so InteractiveBrowserCredential
        # is only used on other platforms.
        if platform.system().lower() == "windows":
            return None
        return AADAuth.get_interactive_token()
    return None
def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]:
    """
    Process connection string and handle authentication.

    The string is split on ``;``. Authentication parameters are extracted
    via ``process_auth_parameters``. When a Python-side auth type is detected,
    sensitive entries are stripped and a packed token is placed in the
    returned ``attrs_before`` dict under attribute id 1256.

    Args:
        connection_string: The connection string to process

    Returns:
        Tuple[str, Optional[Dict]]: The processed connection string (always
        terminated by ``;``) and the attrs_before dict, or ``None`` when no
        token was acquired.

    Raises:
        ValueError: If the connection string is not a string, is empty, or
            contains no ``key=value`` entry.
    """
    # Validation order matters: the type check must run before the emptiness check.
    if not isinstance(connection_string, str):
        raise ValueError("Connection string must be a string")
    if not connection_string:
        raise ValueError("Connection string cannot be empty")

    segments = connection_string.split(";")
    has_key_value = any("=" in segment for segment in segments)
    if not has_key_value:
        raise ValueError("Invalid connection string format")

    # ODBC pre-connect attribute id for the access token (SQL_COPT_SS_ACCESS_TOKEN).
    access_token_attr = 1256

    kept, auth_type = process_auth_parameters(segments)
    attrs_before = None
    if auth_type:
        kept = remove_sensitive_params(kept)
        token_struct = get_auth_token(auth_type)
        if token_struct:
            attrs_before = {access_token_attr: token_struct}

    return ";".join(kept) + ";", attrs_before