Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,18 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
preparing it for further operations such as connecting to the
database, executing queries, etc.
"""
self.connection_str = self._construct_connection_string(
connection_str, **kwargs
)
self._attrs_before = attrs_before or {}
# Get connection string and potential attrs_before from construction
connection_result = self._construct_connection_string(connection_str, **kwargs)

if isinstance(connection_result, tuple):
self.connection_str, attrs_from_driver = connection_result
# Merge with any existing attrs_before
self._attrs_before = attrs_before or {}
self._attrs_before.update(attrs_from_driver)
else:
self.connection_str = connection_result
self._attrs_before = attrs_before or {}

self._closed = False

# Using WeakSet which automatically removes cursors when they are no longer in use
Expand All @@ -90,10 +98,18 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st
**kwargs: Additional key/value pairs for the connection string.

Returns:
str: The constructed connection string.
Union[str, Tuple[str, dict]]: Either the constructed connection string,
or a tuple of (connection string, attrs_before dict)
"""
# Add the driver attribute to the connection string
conn_str = add_driver_to_connection_str(connection_str)
result = add_driver_to_connection_str(connection_str)

# Handle both string and tuple return types
if isinstance(result, tuple):
conn_str, attrs_before = result
else:
conn_str = result
attrs_before = None

# Add additional key-value pairs to the connection string
for key, value in kwargs.items():
Expand All @@ -116,6 +132,8 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st
if ENABLE_LOGGING:
logger.info("Final connection string: %s", conn_str)

if attrs_before:
Comment thread
jahnvi480 marked this conversation as resolved.
Outdated
return conn_str, attrs_before
return conn_str

@property
Expand Down
59 changes: 46 additions & 13 deletions mssql_python/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@ def add_driver_to_connection_str(connection_str):
connection_str (str): The original connection string.

Returns:
str: The connection string with the DDBC driver added.

Raises:
Exception: If the connection string is invalid.
Union[str, Tuple[str, dict]]: Either the connection string with driver added,
or a tuple of (connection string, attrs_before dict)
"""
driver_name = "Driver={ODBC Driver 18 for SQL Server}"
try:
# Strip any leading or trailing whitespace from the connection string
connection_str = connection_str.strip()
connection_str = add_driver_name_to_app_parameter(connection_str)
result = add_driver_name_to_app_parameter(connection_str)

# Handle both regular string and tuple return types
attrs_before = None
if isinstance(result, tuple):
connection_str, attrs_before = result
else:
connection_str = result

# Split the connection string into individual attributes
connection_attributes = connection_str.split(";")
Expand All @@ -50,15 +55,16 @@ def add_driver_to_connection_str(connection_str):
final_connection_attributes.insert(0, driver_name)
connection_str = ";".join(final_connection_attributes)

if attrs_before:
Comment thread
jahnvi480 marked this conversation as resolved.
Outdated
return connection_str, attrs_before
return connection_str

except Exception as e:
raise Exception(
"Invalid connection string, Please follow the format: "
"Server=server_name;Database=database_name;UID=user_name;PWD=password"
) from e

return connection_str


def check_error(handle_type, handle, ret):
"""
Check for errors and raise an exception if an error is found.
Expand All @@ -80,37 +86,64 @@ def check_error(handle_type, handle, ret):

def add_driver_name_to_app_parameter(connection_string):
"""
Modifies the input connection string by appending the APP name.
Modifies the input connection string by appending the APP name and handling AAD auth.

Args:
connection_string (str): The input connection string.

Returns:
str: The modified connection string.
Union[str, Tuple[str, bytes]]: Either the modified connection string,
or a tuple of (connection string, token bytes) if AAD auth is needed
"""
import sys

# Split the input string into key-value pairs
parameters = connection_string.split(";")

# Initialize variables
app_found = False
modified_parameters = []
has_aad_interactive = False

# Iterate through the key-value pairs
for param in parameters:
param = param.strip()
if not param:
continue

if sys.platform.startswith("win"):
if param.lower().startswith("authentication="):
# Handle AAD Interactive authentication
key, auth_value = param.split("=", 1)
if auth_value.lower() == "activedirectoryinteractive":
has_aad_interactive = True
# Only keep the auth parameter on Windows
modified_parameters.append(param)
continue
if param.lower().startswith("app="):
# Overwrite the value with 'MSSQL-Python'
app_found = True
key, _ = param.split("=", 1)
modified_parameters.append(f"{key}=MSSQL-Python")
else:
# Keep other parameters as is
modified_parameters.append(param)

# If APP key is not found, append it
if not app_found:
modified_parameters.append("APP=MSSQL-Python")

# Join the parameters back into a connection string
# Handle AAD Interactive auth for non-Windows platforms
if has_aad_interactive and platform.system().lower() != "windows":
try:
from azure.identity import InteractiveBrowserCredential
import struct
except ImportError:
raise ImportError("Please install azure-identity: pip install azure-identity")

credential = InteractiveBrowserCredential()
token_bytes = credential.get_token("https://database.windows.net/.default").token.encode("UTF-16-LE")
token_struct = struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
return ";".join(modified_parameters) + ";", {1256: token_struct}

return ";".join(modified_parameters) + ";"


Expand Down
4 changes: 2 additions & 2 deletions tests/test_003_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
assert "Driver={ODBC Driver 18 for SQL Server};APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
Comment thread Dismissed

def test_connection_string_with_attrs_before(db_connection):
# Check if the connection string is constructed correctly with attrs_before
Expand All @@ -70,7 +70,7 @@
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
assert "Driver={ODBC Driver 18 for SQL Server};APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
Comment thread Dismissed

def test_autocommit_default(db_connection):
assert db_connection.autocommit is True, "Autocommit should be True by default"
Expand Down
Loading