Skip to content

Commit 3f721c4

Browse files
authored
Merge branch 'main' into jahnvi/ghissue_203
2 parents 8f52c85 + b786900 commit 3f721c4

File tree

6 files changed

+1736
-0
lines changed

6 files changed

+1736
-0
lines changed

mssql_python/cursor.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
)
3737

3838
if TYPE_CHECKING:
39+
import pyarrow # type: ignore
3940
from mssql_python.connection import Connection
41+
else:
42+
pyarrow = None
4043

4144
# Constants for string handling
4245
MAX_INLINE_CHAR: int = (
@@ -788,6 +791,19 @@ def _check_closed(self) -> None:
788791
ddbc_error="",
789792
)
790793

794+
def _ensure_pyarrow(self) -> Any:
    """
    Lazily import the optional pyarrow dependency.

    Returns:
        The imported ``pyarrow`` module.

    Raises:
        ImportError: If pyarrow is not installed, with a hint telling the
            user how to enable the Arrow fetch methods.
    """
    try:
        import pyarrow as _pa
    except ImportError as exc:
        # Re-raise with an actionable message, chaining the original error.
        raise ImportError(
            "pyarrow is required for Arrow fetch methods. Please install pyarrow."
        ) from exc
    return _pa
806+
791807
def setinputsizes(self, sizes: List[Union[int, tuple]]) -> None:
792808
"""
793809
Sets the type information to be used for parameters in execute and executemany.
@@ -2540,6 +2556,94 @@ def fetchall(self) -> List[Row]:
25402556
# On error, don't increment rownumber - rethrow the error
25412557
raise e
25422558

2559+
def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
    """
    Fetch a single pyarrow Record Batch of the specified size from the
    query result set.

    Args:
        batch_size: Maximum number of rows to fetch in the Record Batch.
            Values below zero are clamped to 0 (fetches no rows; useful
            to obtain the schema without consuming the result set).

    Returns:
        A pyarrow RecordBatch object containing up to batch_size rows.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pyarrow = self._ensure_pyarrow()

    # NOTE(review): presumably resets position tracking when a result set
    # exists but fetching has not started — confirm against _reset_rownumber.
    if not self._has_result_set and self.description:
        self._reset_rownumber()

    # The native layer fills `capsules` with Arrow C Data Interface
    # PyCapsules (schema + array) for zero-copy import into pyarrow.
    capsules = []
    ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0))
    check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

    batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules)

    # Collect any diagnostic records produced by the fetch.
    if self.hstmt:
        self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

    # Update rownumber for the number of rows actually fetched
    num_fetched = batch.num_rows
    if num_fetched > 0 and self._has_result_set:
        self._next_row_index += num_fetched
        # rownumber is 0-based index of the last row delivered.
        self._rownumber = self._next_row_index - 1

    # Centralize rowcount assignment after fetch: 0 only when nothing was
    # ever fetched from this result set, otherwise total rows consumed.
    if num_fetched == 0 and self._next_row_index == 0:
        self.rowcount = 0
    else:
        self.rowcount = self._next_row_index

    return batch
2598+
2599+
def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
    """
    Fetch the entire result as a pyarrow Table.

    Args:
        batch_size: Size of the Record Batches which make up the Table.

    Returns:
        A pyarrow Table containing all remaining rows from the result set.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pa = self._ensure_pyarrow()

    collected: list["pyarrow.RecordBatch"] = []
    while True:
        chunk = self.arrow_batch(batch_size)
        # A short (or forced-empty) batch marks the end of the result set.
        exhausted = batch_size <= 0 or chunk.num_rows < batch_size
        # Keep non-empty chunks; keep an empty one only when it is the
        # sole batch, so the Table still carries the result schema.
        if chunk.num_rows > 0 or not collected:
            collected.append(chunk)
        if exhausted:
            break
    return pa.Table.from_batches(collected, schema=collected[0].schema)
2621+
2622+
def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
    """
    Fetch the result as a pyarrow RecordBatchReader, which yields Record
    Batches of the specified size until the current result set is
    exhausted.

    Args:
        batch_size: Size of the Record Batches produced by the reader.

    Returns:
        A pyarrow RecordBatchReader for the result set.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Check if the cursor is closed
    pa = self._ensure_pyarrow()

    # A zero-row fetch yields the schema without consuming any rows.
    schema = self.arrow_batch(0).schema

    def _batches():
        # Stream batches lazily; an empty batch signals exhaustion.
        while True:
            chunk = self.arrow_batch(batch_size)
            if chunk.num_rows == 0:
                return
            yield chunk

    return pa.RecordBatchReader.from_batches(schema, _batches())
2646+
25432647
def nextset(self) -> Union[bool, None]:
25442648
"""
25452649
Skip to the next available result set.

mssql_python/mssql_python.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Type stubs for mssql_python package - based on actual public API
77
from typing import Any, Dict, List, Optional, Union, Tuple, Sequence, Callable, Iterator
88
import datetime
99
import logging
10+
import pyarrow
1011

1112
# GLOBALS - DB-API 2.0 Required Module Globals
1213
# https://www.python.org/dev/peps/pep-0249/#module-interface
@@ -199,6 +200,11 @@ class Cursor:
199200
def setinputsizes(self, sizes: List[Union[int, Tuple[Any, ...]]]) -> None: ...
200201
def setoutputsize(self, size: int, column: Optional[int] = None) -> None: ...
201202

203+
# Arrow Extension Methods (requires pyarrow)
# Single batch of up to batch_size rows.
def arrow_batch(self, batch_size: int = 8192) -> pyarrow.RecordBatch: ...
# Full remaining result set as one Table.
def arrow(self, batch_size: int = 8192) -> pyarrow.Table: ...
# Lazy streaming reader over the result set.
def arrow_reader(self, batch_size: int = 8192) -> pyarrow.RecordBatchReader: ...
207+
202208
# DB-API 2.0 Connection Object
203209
# https://www.python.org/dev/peps/pep-0249/#connection-objects
204210
class Connection:

0 commit comments

Comments
 (0)