Skip to content

Commit c8345a5

Browse files
committed
Conform to spec by reading as pyarrow.Table not pyarrow.RecordBatch
1 parent 3db3f38 commit c8345a5

File tree

9 files changed

+268
-159
lines changed

9 files changed

+268
-159
lines changed

apis/python/src/tiledbsoma/soma_dataframe.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def read(
169169
# TODO: batch_size
170170
# TODO: partition,
171171
# TODO: platform_config,
172-
) -> Iterator[pa.RecordBatch]:
172+
) -> Iterator[pa.Table]:
173173
"""
174-
Read a user-defined subset of data, addressed by the dataframe indexing column, optionally filtered, and return results as one or more ``Arrow.RecordBatch``.
174+
Read a user-defined subset of data, addressed by the dataframe indexing column, optionally filtered, and return results as one or more ``Arrow.Table``.
175175
176176
:param ids: Which rows to read. Defaults to ``None``, meaning no constraint -- all rows.
177177
@@ -217,18 +217,16 @@ def read(
217217
else:
218218
iterator = query.df[ids]
219219

220-
for df in iterator:
221-
batches = df.to_batches()
222-
for batch in batches:
223-
# XXX COMMENT MORE
224-
# This is the 'decode on read' part of our logic; in dim_select we have the
225-
# 'encode on write' part.
226-
# Context: https://github.com/single-cell-data/TileDB-SOMA/issues/99.
227-
#
228-
# Also: don't materialize these on read
229-
# TODO: get the arrow syntax for drop
230-
# df.drop(ROWID, axis=1)
231-
yield util_arrow.ascii_to_unicode_pyarrow_readback(batch)
220+
for table in iterator:
221+
# XXX COMMENT MORE
222+
# This is the 'decode on read' part of our logic; in dim_select we have the
223+
# 'encode on write' part.
224+
# Context: https://github.com/single-cell-data/TileDB-SOMA/issues/99.
225+
#
226+
# Also: don't materialize these on read
227+
# TODO: get the arrow syntax for drop
228+
# df.drop(ROWID, axis=1)
229+
yield util_arrow.ascii_to_unicode_pyarrow_readback(table)
232230

233231
def read_all(
234232
self,
@@ -243,11 +241,11 @@ def read_all(
243241
# TODO: partition,
244242
# TODO: result_order,
245243
# TODO: platform_config,
246-
) -> pa.RecordBatch:
244+
) -> pa.Table:
247245
"""
248-
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the record batches found. Its nominal use is to simply unit-test cases.
246+
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
249247
"""
250-
return util_arrow.concat_batches(
248+
return util_arrow.concat_tables(
251249
self.read(
252250
ids=ids,
253251
value_filter=value_filter,
@@ -273,13 +271,13 @@ def _get_is_sparse(self) -> bool:
273271

274272
return self._cached_is_sparse
275273

276-
def write(self, values: pa.RecordBatch) -> None:
274+
def write(self, values: pa.Table) -> None:
277275
"""
278-
Write an Arrow.RecordBatch to the persistent object.
276+
Write an Arrow.Table to the persistent object.
279277
280-
:param values: An Arrow.RecordBatch containing all columns, including the index columns. The schema for the values must match the schema for the ``SOMADataFrame``.
278+
:param values: An Arrow.Table containing all columns, including the index columns. The schema for the values must match the schema for the ``SOMADataFrame``.
281279
282-
The ``values`` Arrow RecordBatch must contain a ``soma_rowid`` (uint64) column, indicating which rows are being written.
280+
The ``values`` Arrow Table must contain a ``soma_rowid`` (uint64) column, indicating which rows are being written.
283281
"""
284282
self._shape = None # cache-invalidate
285283

apis/python/src/tiledbsoma/soma_dense_nd_array.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def read_tensor(
168168
)
169169

170170
def read_numpy(
171+
if row_ids is None:
172+
if col_ids is None:
173+
iterator = query.df[:, :]
174+
else:
175+
iterator = query.df[:, col_ids]
176+
else:
177+
if col_ids is None:
178+
iterator = query.df[row_ids, :]
179+
else:
180+
iterator = query.df[row_ids, col_ids]
181+
182+
for table in iterator:
183+
yield table
184+
185+
def read_as_pandas(
171186
self,
172187
coords: SOMADenseNdCoordinates,
173188
*,
@@ -178,11 +193,77 @@ def read_numpy(
178193
"""
179194
return cast(
180195
np.ndarray, self.read_tensor(coords, result_order=result_order).to_numpy()
196+
with self._tiledb_open() as A:
197+
query = A.query(return_incomplete=True)
198+
199+
if row_ids is None:
200+
if col_ids is None:
201+
iterator = query.df[:, :]
202+
else:
203+
iterator = query.df[:, col_ids]
204+
else:
205+
if col_ids is None:
206+
iterator = query.df[row_ids, :]
207+
else:
208+
iterator = query.df[row_ids, col_ids]
209+
210+
for df in iterator:
211+
# Make this opt-in only. For large arrays, this df.set_index is time-consuming
212+
# so we should not do it without direction.
213+
if set_index:
214+
df.set_index(self._tiledb_dim_names(), inplace=True)
215+
yield df
216+
217+
def read_all(
218+
self,
219+
*,
220+
# TODO: find the right syntax to get the typechecker to accept args like ``ids=slice(0,10)``
221+
# ids: Optional[Union[Sequence[int], Slice]] = None,
222+
row_ids: Optional[Sequence[int]] = None,
223+
col_ids: Optional[Sequence[int]] = None,
224+
result_order: Optional[str] = None,
225+
# TODO: batch_size
226+
# TODO: partition,
227+
# TODO: batch_format,
228+
# TODO: platform_config,
229+
) -> pa.Table:
230+
"""
231+
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
232+
"""
233+
return util_arrow.concat_tables(
234+
self.read(
235+
row_ids=row_ids,
236+
col_ids=col_ids,
237+
result_order=result_order,
238+
)
181239
)
182240

183241
def write_tensor(
184242
self,
185243
coords: SOMADenseNdCoordinates,
244+
*,
245+
row_ids: Optional[Sequence[int]] = None,
246+
col_ids: Optional[Sequence[int]] = None,
247+
set_index: Optional[bool] = False,
248+
) -> pa.Table:
249+
"""
250+
This is a convenience method around ``read_as_pandas``. It iterates the return value from ``read_as_pandas`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
251+
"""
252+
dataframes = []
253+
generator = self.read_as_pandas(
254+
row_ids=row_ids,
255+
col_ids=col_ids,
256+
set_index=set_index,
257+
)
258+
for dataframe in generator:
259+
dataframes.append(dataframe)
260+
return pd.concat(dataframes)
261+
262+
def write(
263+
self,
264+
# TODO: rework callsites with regard to the very latest spec rev
265+
# coords: Union[tuple, tuple[slice], NTuple, List[int]],
266+
coords: Any,
186267
values: pa.Tensor,
187268
) -> None:
188269
"""

apis/python/src/tiledbsoma/soma_indexed_dataframe.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def read(
209209
column_names: Optional[Sequence[str]] = None,
210210
result_order: Optional[SOMAResultOrder] = None,
211211
# TODO: more arguments
212-
) -> Iterator[pa.RecordBatch]:
212+
) -> Iterator[pa.Table]:
213213
"""
214-
Read a user-defined subset of data, addressed by the dataframe indexing columns, optionally filtered, and return results as one or more Arrow.RecordBatch.
214+
Read a user-defined subset of data, addressed by the dataframe indexing columns, optionally filtered, and return results as one or more Arrow.Table.
215215
216216
:param ids: for each index dimension, which rows to read. Defaults to ``None``, meaning no constraint -- all IDs.
217217
@@ -258,14 +258,12 @@ def read(
258258
else:
259259
iterator = query.df[ids]
260260

261-
for df in iterator:
262-
batches = df.to_batches()
263-
for batch in batches:
264-
# XXX COMMENT MORE
265-
# This is the 'decode on read' part of our logic; in dim_select we have the
266-
# 'encode on write' part.
267-
# Context: # https://github.com/single-cell-data/TileDB-SOMA/issues/99.
268-
yield util_arrow.ascii_to_unicode_pyarrow_readback(batch)
261+
for table in iterator:
262+
# XXX COMMENT MORE
263+
# This is the 'decode on read' part of our logic; in dim_select we have the
264+
# 'encode on write' part.
265+
# Context: # https://github.com/single-cell-data/TileDB-SOMA/issues/99.
266+
yield util_arrow.ascii_to_unicode_pyarrow_readback(table)
269267

270268
def read_all(
271269
self,
@@ -279,19 +277,19 @@ def read_all(
279277
# TODO: batch_size
280278
# TODO: partition,
281279
# TODO: platform_config,
282-
) -> pa.RecordBatch:
280+
) -> pa.Table:
283281
"""
284-
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the record batches found. Its nominal use is to simply unit-test cases.
282+
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
285283
"""
286-
return util_arrow.concat_batches(
284+
return util_arrow.concat_tables(
287285
self.read(ids=ids, value_filter=value_filter, column_names=column_names)
288286
)
289287

290-
def write(self, values: pa.RecordBatch) -> None:
288+
def write(self, values: pa.Table) -> None:
291289
"""
292-
Write an Arrow.RecordBatch to the persistent object. As duplicate index values are not allowed, index values already present in the object are overwritten and new index values are added.
290+
Write an Arrow.Table to the persistent object. As duplicate index values are not allowed, index values already present in the object are overwritten and new index values are added.
293291
294-
:param values: An Arrow.RecordBatch containing all columns, including the index columns. The schema for the values must match the schema for the ``SOMAIndexedDataFrame``.
292+
:param values: An Arrow.Table containing all columns, including the index columns. The schema for the values must match the schema for the ``SOMAIndexedDataFrame``.
295293
"""
296294
self._shape = None # cache-invalidate
297295

apis/python/src/tiledbsoma/soma_sparse_nd_array.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ def read_sparse_tensor(
158158
*,
159159
format: Literal["coo", "csr", "csc"] = "coo",
160160
) -> Iterator[Union[pa.SparseCOOTensor, pa.SparseCSCMatrix, pa.SparseCSRMatrix]]:
161+
# TODO: find the right syntax to get the typechecker to accept args like ``ids=slice(0,10)``
162+
# row_ids: Optional[Union[Sequence[int], Slice]] = None,
163+
# col_ids: Optional[Union[Sequence[int], Slice]] = None,
164+
row_ids: Optional[Sequence[int]] = None,
165+
col_ids: Optional[Sequence[int]] = None,
166+
result_order: Optional[str] = None,
167+
# TODO: batch_size
168+
# TODO: partition,
169+
# TODO: batch_format,
170+
# TODO: platform_config,
171+
) -> Iterator[pa.Table]:
161172
"""
162173
Read a user-defined slice of the SparseNdArray and return as an Arrow sparse tensor.
163174
@@ -215,6 +226,16 @@ def read_sparse_tensor(
215226
yield pa.SparseCSCMatrix.from_scipy(scipy_coo.tocsc())
216227

217228
def read_table(self, coords: SOMASparseNdCoordinates) -> Iterator[pa.Table]:
229+
for table in iterator:
230+
yield table
231+
232+
def read_as_pandas(
233+
self,
234+
*,
235+
row_ids: Optional[Sequence[int]] = None,
236+
col_ids: Optional[Sequence[int]] = None,
237+
set_index: Optional[bool] = False,
238+
) -> pd.DataFrame:
218239
"""
219240
Read a user-defined slice of the sparse array and return in COO format
220241
as an Arrow Table
@@ -223,6 +244,53 @@ def read_table(self, coords: SOMASparseNdCoordinates) -> Iterator[pa.Table]:
223244
query = A.query(
224245
return_arrow=True,
225246
return_incomplete=True,
247+
dim_names = None
248+
if set_index:
249+
dim_names = self._tiledb_dim_names()
250+
251+
with self._tiledb_open() as A:
252+
query = A.query(return_incomplete=True)
253+
254+
if row_ids is None:
255+
if col_ids is None:
256+
iterator = query.df[:, :]
257+
else:
258+
iterator = query.df[:, col_ids]
259+
else:
260+
if col_ids is None:
261+
iterator = query.df[row_ids, :]
262+
else:
263+
iterator = query.df[row_ids, col_ids]
264+
265+
for df in iterator:
266+
# Make this opt-in only. For large arrays, this df.set_index is time-consuming
267+
# so we should not do it without direction.
268+
if set_index:
269+
df.set_index(dim_names, inplace=True)
270+
yield df
271+
272+
def read_all(
273+
self,
274+
*,
275+
# TODO: find the right syntax to get the typechecker to accept args like ``ids=slice(0,10)``
276+
# row_ids: Optional[Union[Sequence[int], Slice]] = None,
277+
# col_ids: Optional[Union[Sequence[int], Slice]] = None,
278+
row_ids: Optional[Sequence[int]] = None,
279+
col_ids: Optional[Sequence[int]] = None,
280+
result_order: Optional[str] = None,
281+
# TODO: batch_size
282+
# TODO: partition,
283+
# TODO: batch_format,
284+
# TODO: platform_config,
285+
) -> pa.Table:
286+
"""
287+
This is a convenience method around ``read``. It iterates the return value from ``read`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
288+
"""
289+
return util_arrow.concat_tables(
290+
self.read(
291+
row_ids=row_ids,
292+
col_ids=col_ids,
293+
result_order=result_order,
226294
)
227295
for arrow_tbl in query.df[coords]:
228296
yield arrow_tbl
@@ -238,8 +306,15 @@ def read_as_pandas(self, coords: SOMASparseNdCoordinates) -> Iterator[pd.DataFra
238306
def read_as_pandas_all(
239307
self, coords: Optional[SOMASparseNdCoordinates] = None
240308
) -> pd.DataFrame:
309+
self,
310+
*,
311+
row_ids: Optional[Sequence[int]] = None,
312+
col_ids: Optional[Sequence[int]] = None,
313+
set_index: Optional[bool] = False,
314+
) -> pa.Table:
241315
"""
242316
Return the sparse array as a single Pandas DataFrame containing COO data.
317+
This is a convenience method around ``read_as_pandas``. It iterates the return value from ``read_as_pandas`` and returns a concatenation of all the table-pieces found. Its nominal use is to simply unit-test cases.
243318
"""
244319
if coords is None:
245320
coords = (slice(None),) * self.ndims

apis/python/src/tiledbsoma/util_arrow.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,36 +121,31 @@ def get_arrow_schema_from_tiledb_uri(
121121
return pa.schema(arrow_schema_dict)
122122

123123

124-
def ascii_to_unicode_pyarrow_readback(record_batch: pa.RecordBatch) -> pa.RecordBatch:
124+
def ascii_to_unicode_pyarrow_readback(table: pa.Table) -> pa.Table:
125125
"""
126126
Implements the 'decode on read' part of our ASCII/Unicode logic
127127
"""
128128
# TODO: COMMENT/LINK HEAVILY
129-
names = [ofield.name for ofield in record_batch.schema]
129+
names = [ofield.name for ofield in table.schema]
130130
new_fields = []
131131
for name in names:
132-
old_field = record_batch[name]
133-
if isinstance(old_field, pa.LargeBinaryArray):
132+
old_field = table[name]
133+
if len(old_field) > 0 and isinstance(old_field[0], pa.LargeBinaryScalar):
134134
nfield = pa.array(
135135
[element.as_py().decode("utf-8") for element in old_field]
136136
)
137137
new_fields.append(nfield)
138138
else:
139139
new_fields.append(old_field)
140-
return pa.RecordBatch.from_arrays(new_fields, names=names)
140+
return pa.Table.from_arrays(new_fields, names=names)
141141

142142

143-
def concat_batches(batch_generator: Iterator[pa.RecordBatch]) -> pa.RecordBatch:
143+
def concat_tables(table_generator: Iterator[Any]) -> pa.Table:
144144
"""
145-
Iterates a generator of ``pyarrow.RecordBatch`` (e.g. ``SOMADataFrame.read``) and returns a concatenation of all the record batches found. The nominal use is to simply unit-test cases.
145+
Iterates a generator of ``pyarrow.Table`` (e.g. ``SOMADataFrame.read``) and returns a concatenation of all the table-pieces found. The nominal use is to simply unit-test cases.
146146
"""
147-
batches = []
148-
for batch in batch_generator:
149-
batches.append(batch)
150-
assert len(batches) > 0
151-
names = [field.name for field in batches[0].schema]
152-
arrays = []
153-
for name in names:
154-
array = pa.concat_arrays([batch[name] for batch in batches])
155-
arrays.append(array)
156-
return pa.RecordBatch.from_arrays(arrays, names=names)
147+
tables = []
148+
for table in table_generator:
149+
tables.append(table)
150+
assert len(tables) > 0
151+
return pa.concat_tables(tables)

apis/python/tests/test_soma_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def create_and_populate_dataframe(dataframe: soma.SOMADataFrame) -> None:
2525
pydict["foo"] = [10, 20, 30, 40, 50]
2626
pydict["bar"] = [4.1, 5.2, 6.3, 7.4, 8.5]
2727
pydict["baz"] = ["apple", "ball", "cat", "dog", "egg"]
28-
rb = pa.RecordBatch.from_pydict(pydict)
28+
rb = pa.Table.from_pydict(pydict)
2929
dataframe.write(rb)
3030

3131

0 commit comments

Comments
 (0)