Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 110 additions & 25 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
from mssql_python import ddbc_bindings
from .row import Row
from typing import Sequence, Any


logger = get_logger()

Expand Down Expand Up @@ -539,7 +541,7 @@ def _map_data_type(self, sql_type):
# Add more mappings as needed
}
return sql_to_python_type.get(sql_type, str)

def execute(
self,
operation: str,
Expand Down Expand Up @@ -617,38 +619,121 @@ def execute(

# Initialize description after execution
self._initialize_description()


# def executemany(self, operation: str, seq_of_parameters: list) -> None:
# self._check_closed()
# self._reset_cursor()

# if not seq_of_parameters:
# return

# # Transpose to column-major format
# columns = list(zip(*seq_of_parameters)) # Each column: tuple of values
# sample_params = seq_of_parameters[0]
# param_info = ddbc_bindings.ParamInfo

# parameters_type = []
# for i, sample_val in enumerate(sample_params):
# paraminfo = self._create_parameter_types_list(sample_val, param_info, sample_params, i)

# # Fix: Adjust string column sizes based on actual max length across all rows
# if isinstance(sample_val, str):
# max_len = max(
# (len(v) for v in columns[i] if isinstance(v, str)),
# default=1 # fallback if all values are None
# )
# paraminfo.columnSize = max_len

# parameters_type.append(paraminfo)

# # Now execute with adjusted parameter types
# ret = ddbc_bindings.SQLExecuteMany(
# self.hstmt, operation, columns, parameters_type, len(seq_of_parameters)
# )
# check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
# self._initialize_description()

def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
Comment thread
gargsaumya marked this conversation as resolved.
Comment thread
gargsaumya marked this conversation as resolved.
"""
Convert list of rows (row-wise) into list of columns (column-wise),
for array binding via ODBC.

Example:
Input: [(1, "a"), (2, "b")]
Output: [[1, 2], ["a", "b"]]
"""
if not seq_of_parameters:
return []

num_params = len(seq_of_parameters[0])
columnwise = [[] for _ in range(num_params)]
for row in seq_of_parameters:
if len(row) != num_params:
raise ValueError("Inconsistent parameter row size in executemany()")
for i, val in enumerate(row):
columnwise[i].append(val)
return columnwise

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Prepare a database operation and execute it against all parameter sequences.
Comment thread
gargsaumya marked this conversation as resolved.
This version uses column-wise parameter binding and a single batched SQLExecute().
"""
self._check_closed()
self._reset_cursor()

Args:
operation: SQL query or command.
seq_of_parameters: Sequence of sequences or mappings of parameters.
if not seq_of_parameters:
self.rowcount = 0
return

# # Infer types from the first row
# first_row = list(seq_of_parameters[0])
# param_info = ddbc_bindings.ParamInfo
# parameters_type = [
# self._create_parameter_types_list(param, param_info, first_row, i)
# for i, param in enumerate(first_row)
# ]
param_info = ddbc_bindings.ParamInfo
param_count = len(seq_of_parameters[0])
parameters_type = []

Raises:
Error: If the operation fails.
"""
self._check_closed() # Check if the cursor is closed
for col_index in range(param_count):
Comment thread
gargsaumya marked this conversation as resolved.
# Use the longest string (or most precise value) in that column for inference
column = [row[col_index] for row in seq_of_parameters]
sample_value = column[0]

self._reset_cursor()
# For strings, pick the value with max len
if isinstance(sample_value, str):
sample_value = max(column, key=lambda s: len(str(s)) if s is not None else 0)

first_execution = True
total_rowcount = 0
for parameters in seq_of_parameters:
parameters = list(parameters)
if ENABLE_LOGGING:
logger.info("Executing query with parameters: %s", parameters)
prepare_stmt = first_execution
first_execution = False
self.execute(
operation, parameters, use_prepare=prepare_stmt, reset_cursor=False
)
if self.rowcount != -1:
total_rowcount += self.rowcount
else:
total_rowcount = -1
self.rowcount = total_rowcount
# For decimals, use the one with highest precision
elif isinstance(sample_value, decimal.Decimal):
sample_value = max(column, key=lambda d: len(d.as_tuple().digits) if d is not None else 0)

param = sample_value
dummy_row = list(seq_of_parameters[0]) # to pass for `_get_numeric_data()` mutation
parameters_type.append(self._create_parameter_types_list(param, param_info, dummy_row, col_index))


# Transpose to column-wise format for array binding
columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)

# Execute batched statement
ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
columnwise_params,
parameters_type,
len(seq_of_parameters)
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

self.rowcount = len(seq_of_parameters)
self.last_executed_stmt = operation
self._initialize_description()

def fetchone(self) -> Union[None, Row]:
"""
Expand Down
174 changes: 174 additions & 0 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@
return static_cast<ParamType*>(paramBuffers.back().get());
}

template <typename ParamType>
ParamType* AllocateParamBufferArray(std::vector<std::shared_ptr<void>>& paramBuffers,
size_t count) {
std::shared_ptr<ParamType> buffer(new ParamType[count], std::default_delete<ParamType[]>());
ParamType* raw = buffer.get();
paramBuffers.push_back(buffer);
return raw;
}

std::string DescribeChar(unsigned char ch) {
if (ch >= 32 && ch <= 126) {
return std::string("'") + static_cast<char>(ch) + "'";
Expand Down Expand Up @@ -933,6 +942,170 @@
}
}

SQLRETURN BindParameterArray(SQLHANDLE hStmt,
Comment thread
gargsaumya marked this conversation as resolved.
const py::list& columnwise_params,
const std::vector<ParamInfo>& paramInfos,
size_t paramSetSize,
std::vector<std::shared_ptr<void>>& paramBuffers) {
LOG("Starting column-wise parameter array binding. paramSetSize: {}, paramCount: {}", paramSetSize, columnwise_params.size());

for (int paramIndex = 0; paramIndex < columnwise_params.size(); ++paramIndex) {
const py::list& columnValues = columnwise_params[paramIndex].cast<py::list>();
const ParamInfo& info = paramInfos[paramIndex];

if (columnValues.size() != paramSetSize) {
ThrowStdException("Column " + std::to_string(paramIndex) + " has mismatched size.");
Comment thread
gargsaumya marked this conversation as resolved.
Outdated
}

void* dataPtr = nullptr;
SQLLEN* strLenOrIndArray = nullptr;
SQLLEN bufferLength = 0;

switch (info.paramCType) {
Comment thread
gargsaumya marked this conversation as resolved.
Outdated
case SQL_C_LONG: {
int* dataArray = AllocateParamBufferArray<int>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
dataArray[i] = columnValues[i].cast<int>();
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
break;
}
case SQL_C_DOUBLE: {
double* dataArray = AllocateParamBufferArray<double>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
dataArray[i] = columnValues[i].cast<double>();
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
break;
}
case SQL_C_WCHAR: {
SQLWCHAR* wcharArray = AllocateParamBufferArray<SQLWCHAR>(paramBuffers, paramSetSize * (info.columnSize + 1));
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
strLenOrIndArray[i] = SQL_NULL_DATA;
std::memset(wcharArray + i * (info.columnSize + 1), 0, (info.columnSize + 1) * sizeof(SQLWCHAR));
continue;
}

std::wstring wstr = columnValues[i].cast<std::wstring>();
if (wstr.length() > info.columnSize) {
Comment thread
gargsaumya marked this conversation as resolved.
Outdated
std::string offending = WideToUTF8(wstr);
ThrowStdException("String too long at param " + std::to_string(paramIndex) +
Comment thread
gargsaumya marked this conversation as resolved.
Outdated
", value: " + offending +
", len: " + std::to_string(wstr.length()) +
" > columnSize: " + std::to_string(info.columnSize));
}
std::memcpy(wcharArray + i * (info.columnSize + 1), wstr.c_str(), (wstr.length() + 1) * sizeof(SQLWCHAR));
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
strLenOrIndArray[i] = SQL_NTS;
}
dataPtr = wcharArray;
bufferLength = (info.columnSize + 1) * sizeof(SQLWCHAR);
break;
}
case SQL_C_TINYINT:
case SQL_C_UTINYINT: {
unsigned char* dataArray = AllocateParamBufferArray<unsigned char>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
const py::handle& value = columnValues[i];
if (!py::isinstance<py::int_>(value)) {
ThrowStdException(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
}
int intVal = value.cast<int>();
if (intVal < 0 || intVal > 255) {
ThrowStdException("UTINYINT value out of range at rowIndex " + std::to_string(i));
}
dataArray[i] = static_cast<unsigned char>(intVal);
}
dataPtr = dataArray;
bufferLength = sizeof(unsigned char);
break;
}
case SQL_C_SHORT: {
short* dataArray = AllocateParamBufferArray<short>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
const py::handle& value = columnValues[i];
if (!py::isinstance<py::int_>(value)) {
ThrowStdException(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
}
int intVal = value.cast<int>();
if (intVal < std::numeric_limits<short>::min() ||
intVal > std::numeric_limits<short>::max()) {
ThrowStdException("SHORT value out of range at rowIndex " + std::to_string(i));
}
dataArray[i] = static_cast<short>(intVal);
}
dataPtr = dataArray;
bufferLength = sizeof(short);
break;
}

default: {
ThrowStdException("BindParameterArray: Unsupported C type: " + std::to_string(info.paramCType));
}
}

RETCODE rc = SQLBindParameter_ptr(
hStmt,
static_cast<SQLUSMALLINT>(paramIndex + 1),
static_cast<SQLUSMALLINT>(info.inputOutputType),
static_cast<SQLSMALLINT>(info.paramCType),
static_cast<SQLSMALLINT>(info.paramSQLType),
info.columnSize,
info.decimalDigits,
dataPtr,
bufferLength,
strLenOrIndArray
);
if (!SQL_SUCCEEDED(rc)) {
LOG("Failed to bind array param {}", paramIndex);
return rc;
}
}
LOG("Finished column-wise parameter array binding.");
return SQL_SUCCESS;
}

SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle,
const std::wstring& query,
const py::list& columnwise_params,
const std::vector<ParamInfo>& paramInfos,
size_t paramSetSize) {
SQLHANDLE hStmt = statementHandle->get();
SQLWCHAR* queryPtr;
#if defined(__APPLE__) || defined(__linux__)
std::vector<SQLWCHAR> queryBuffer = WStringToSQLWCHAR(query);
queryPtr = queryBuffer.data();
#else
queryPtr = const_cast<SQLWCHAR*>(query.c_str());
#endif
RETCODE rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS);
if (!SQL_SUCCEEDED(rc)) return rc;
std::vector<std::shared_ptr<void>> paramBuffers;
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER)paramSetSize, 0);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLExecute_ptr(hStmt);
return rc;
}

// Wrap SQLNumResultCols
SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) {
LOG("Get number of columns in result set");
Expand Down Expand Up @@ -2112,6 +2285,7 @@
m.def("close_pooling", []() {ConnectionPoolManager::getInstance().closePools();});
m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly");
m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements");
m.def("SQLExecuteMany", &SQLExecuteMany_wrap, "Execute statement with multiple parameter sets");
m.def("DDBCSQLRowCount", &SQLRowCount_wrap,
"Get the number of rows affected by the last statement");
m.def("DDBCSQLFetch", &SQLFetch_wrap, "Fetch the next row from the result set");
Expand Down
Loading