|
36 | 36 | ) |
37 | 37 |
|
38 | 38 | if TYPE_CHECKING: |
| 39 | + import pyarrow # type: ignore |
39 | 40 | from mssql_python.connection import Connection |
| 41 | +else: |
| 42 | + pyarrow = None |
40 | 43 |
|
41 | 44 | # Constants for string handling |
42 | 45 | MAX_INLINE_CHAR: int = ( |
@@ -788,6 +791,19 @@ def _check_closed(self) -> None: |
788 | 791 | ddbc_error="", |
789 | 792 | ) |
790 | 793 |
|
| 794 | + def _ensure_pyarrow(self) -> Any: |
| 795 | + """ |
| 796 | + Import and return pyarrow or raise ImportError accordingly. |
| 797 | + """ |
| 798 | + try: |
| 799 | + import pyarrow |
| 800 | + |
| 801 | + return pyarrow |
| 802 | + except ImportError as e: |
| 803 | + raise ImportError( |
| 804 | + "pyarrow is required for Arrow fetch methods. Please install pyarrow." |
| 805 | + ) from e |
| 806 | + |
791 | 807 | def setinputsizes(self, sizes: List[Union[int, tuple]]) -> None: |
792 | 808 | """ |
793 | 809 | Sets the type information to be used for parameters in execute and executemany. |
@@ -2540,6 +2556,94 @@ def fetchall(self) -> List[Row]: |
2540 | 2556 | # On error, don't increment rownumber - rethrow the error |
2541 | 2557 | raise e |
2542 | 2558 |
|
    def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
        """
        Fetch a single pyarrow Record Batch of the specified size from the
        query result set.

        Args:
            batch_size: Maximum number of rows to fetch in the Record Batch.
                Values <= 0 are clamped to 0, yielding an empty
                (schema-only) batch.

        Returns:
            A pyarrow RecordBatch object containing up to batch_size rows.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        self._check_closed()  # Check if the cursor is closed
        pyarrow = self._ensure_pyarrow()

        # NOTE(review): resets the row counter when no result set is marked
        # active but column metadata exists — presumably the first Arrow
        # fetch after execute; confirm against _has_result_set lifecycle.
        if not self._has_result_set and self.description:
            self._reset_rownumber()

        # The native layer appends the Arrow C-interface PyCapsules
        # (schema, array) for the fetched batch into `capsules`.
        capsules = []
        ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0))
        check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

        # Import the C-level data into a RecordBatch via the capsule protocol
        # (private pyarrow API — presumably zero-copy; verify on upgrades).
        batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules)

        # Collect any diagnostic records produced by the fetch.
        if self.hstmt:
            self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

        # Update rownumber for the number of rows actually fetched
        num_fetched = batch.num_rows
        if num_fetched > 0 and self._has_result_set:
            self._next_row_index += num_fetched
            # rownumber is the 0-based index of the last row returned.
            self._rownumber = self._next_row_index - 1

        # Centralize rowcount assignment after fetch
        if num_fetched == 0 and self._next_row_index == 0:
            self.rowcount = 0
        else:
            self.rowcount = self._next_row_index

        return batch
| 2598 | + |
| 2599 | + def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": |
| 2600 | + """ |
| 2601 | + Fetch the entire result as a pyarrow Table. |
| 2602 | +
|
| 2603 | + Args: |
| 2604 | + batch_size: Size of the Record Batches which make up the Table. |
| 2605 | +
|
| 2606 | + Returns: |
| 2607 | + A pyarrow Table containing all remaining rows from the result set. |
| 2608 | + """ |
| 2609 | + self._check_closed() # Check if the cursor is closed |
| 2610 | + pyarrow = self._ensure_pyarrow() |
| 2611 | + |
| 2612 | + batches: list["pyarrow.RecordBatch"] = [] |
| 2613 | + while True: |
| 2614 | + batch = self.arrow_batch(batch_size) |
| 2615 | + if batch.num_rows < batch_size or batch_size <= 0: |
| 2616 | + if not batches or batch.num_rows > 0: |
| 2617 | + batches.append(batch) |
| 2618 | + break |
| 2619 | + batches.append(batch) |
| 2620 | + return pyarrow.Table.from_batches(batches, schema=batches[0].schema) |
| 2621 | + |
| 2622 | + def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": |
| 2623 | + """ |
| 2624 | + Fetch the result as a pyarrow RecordBatchReader, which yields Record |
| 2625 | + Batches of the specified size until the current result set is |
| 2626 | + exhausted. |
| 2627 | +
|
| 2628 | + Args: |
| 2629 | + batch_size: Size of the Record Batches produced by the reader. |
| 2630 | +
|
| 2631 | + Returns: |
| 2632 | + A pyarrow RecordBatchReader for the result set. |
| 2633 | + """ |
| 2634 | + self._check_closed() # Check if the cursor is closed |
| 2635 | + pyarrow = self._ensure_pyarrow() |
| 2636 | + |
| 2637 | + # Fetch schema without advancing cursor |
| 2638 | + schema_batch = self.arrow_batch(0) |
| 2639 | + schema = schema_batch.schema |
| 2640 | + |
| 2641 | + def batch_generator(): |
| 2642 | + while (batch := self.arrow_batch(batch_size)).num_rows > 0: |
| 2643 | + yield batch |
| 2644 | + |
| 2645 | + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) |
| 2646 | + |
2543 | 2647 | def nextset(self) -> Union[bool, None]: |
2544 | 2648 | """ |
2545 | 2649 | Skip to the next available result set. |
|
0 commit comments