Skip to content

Commit c2e253b

Browse files
authored
refactor: add more test setup utils in write_partitioned.rs tests (#2422)
## What changes are proposed in this pull request?

Refactor the `write_partitioned.rs` test suite to be a bit cleaner (less duplication between tests). This will be useful when we add even more tests to this suite.

## How was this change tested?

Refactoring only; covered by existing unit tests.
1 parent 9decd69 commit c2e253b

1 file changed

Lines changed: 116 additions & 121 deletions

File tree

kernel/tests/write_partitioned.rs

Lines changed: 116 additions & 121 deletions
Original file line number | Diff line number | Diff line change
@@ -12,42 +12,14 @@ use delta_kernel::arrow::array::{
1212
use delta_kernel::arrow::datatypes::Schema as ArrowSchema;
1313
use delta_kernel::committer::FileSystemCommitter;
1414
use delta_kernel::engine::arrow_conversion::TryIntoArrow as _;
15-
use delta_kernel::engine::arrow_data::ArrowEngineData;
16-
use delta_kernel::engine::default::executor::tokio::TokioMultiThreadExecutor;
17-
use delta_kernel::engine::default::DefaultEngine;
1815
use delta_kernel::expressions::Scalar;
1916
use delta_kernel::schema::{DataType, StructField, StructType};
2017
use delta_kernel::table_features::ColumnMappingMode;
2118
use delta_kernel::transaction::create_table::create_table;
2219
use delta_kernel::transaction::data_layout::DataLayout;
2320
use delta_kernel::Snapshot;
2421
use rstest::rstest;
25-
use test_utils::{read_scan, test_table_setup_mt};
26-
use url::Url;
27-
28-
// ==============================================================================
29-
// Helpers
30-
// ==============================================================================
31-
32-
async fn write_batch(
33-
snapshot: &Arc<Snapshot>,
34-
engine: &DefaultEngine<TokioMultiThreadExecutor>,
35-
data: RecordBatch,
36-
partition_values: HashMap<String, Scalar>,
37-
) -> Result<Arc<Snapshot>, Box<dyn std::error::Error>> {
38-
let mut txn = snapshot
39-
.clone()
40-
.transaction(Box::new(FileSystemCommitter::new()), engine)?
41-
.with_engine_info("test")
42-
.with_data_change(true);
43-
let write_context = txn.partitioned_write_context(partition_values)?;
44-
let add_meta = engine
45-
.write_parquet(&ArrowEngineData::new(data), &write_context)
46-
.await?;
47-
txn.add_files(add_meta);
48-
let committed = txn.commit(engine)?.unwrap_committed();
49-
Ok(committed.post_commit_snapshot().unwrap().clone())
50-
}
22+
use test_utils::{read_scan, test_table_setup_mt, write_batch_to_table};
5123

5224
// ==============================================================================
5325
// Tests
@@ -64,34 +36,18 @@ async fn test_write_partitioned_normal_values_roundtrip(
6436
#[case] cm_mode: ColumnMappingMode,
6537
) -> Result<(), Box<dyn std::error::Error>> {
6638
// ===== Step 1: Create table and write one row with normal partition values. =====
67-
let (_tmp_dir, table_path, engine) = test_table_setup_mt()?;
68-
let schema = all_types_schema();
69-
let arrow_schema: Arc<ArrowSchema> = Arc::new(schema.as_ref().try_into_arrow()?);
70-
let snapshot = create_all_types_table(&table_path, engine.as_ref(), cm_mode)?;
71-
assert_eq!(snapshot.table_configuration().partition_columns().len(), 13);
72-
73-
let batch = RecordBatch::try_new(arrow_schema, normal_arrow_columns())?;
74-
let snapshot = write_batch(
75-
&snapshot,
76-
engine.as_ref(),
77-
batch,
39+
let (_tmp_dir, table_path, snapshot, engine) = setup_and_write(
40+
all_types_schema(),
41+
PARTITION_COLS,
42+
cm_mode,
43+
normal_arrow_columns(),
7844
normal_partition_values()?,
7945
)
8046
.await?;
47+
assert_eq!(snapshot.table_configuration().partition_columns().len(), 13);
8148

8249
// ===== Step 2: Validate add.path structure in the commit log JSON. =====
83-
let adds = read_add_actions_json(&table_path, 1)?;
84-
assert_eq!(adds.len(), 1, "should have exactly one add action");
85-
let add = &adds[0];
86-
let raw_path = add["path"].as_str().unwrap();
87-
88-
assert!(
89-
!raw_path.contains("://"),
90-
"should produce relative paths, got: {raw_path}"
91-
);
92-
93-
let rel_path = strip_table_root(raw_path, snapshot.table_root());
94-
50+
let (add, rel_path) = read_single_add(&table_path, 1)?;
9551
match cm_mode {
9652
ColumnMappingMode::None => {
9753
// Hive-style path with Hive encoding: colons -> %3A, spaces -> %20.
@@ -108,16 +64,7 @@ async fn test_write_partitioned_normal_values_roundtrip(
10864
assert!(rel_path.ends_with(".parquet"));
10965
}
11066
ColumnMappingMode::Name | ColumnMappingMode::Id => {
111-
// Random 2-char prefix: <2char>/<uuid>.parquet
112-
let segments: Vec<&str> = rel_path.split('/').collect();
113-
assert_eq!(
114-
segments.len(),
115-
2,
116-
"CM on: path should be <prefix>/<file>, got: {rel_path}"
117-
);
118-
assert_eq!(segments[0].len(), 2, "prefix should be 2 chars");
119-
assert!(segments[0].chars().all(|c| c.is_ascii_alphanumeric()));
120-
assert!(segments[1].ends_with(".parquet"));
67+
assert_cm_path(&rel_path);
12168
}
12269
}
12370

@@ -149,16 +96,8 @@ async fn test_write_partitioned_normal_values_roundtrip(
14996
}
15097
}
15198

152-
// ===== Step 4: Read data back via scan and verify all column values. =====
153-
let sorted = read_sorted(&snapshot, engine.clone())?;
154-
assert_normal_values(&sorted);
155-
156-
// ===== Step 5: Checkpoint, reload snapshot from checkpoint, read again. =====
157-
snapshot.checkpoint(engine.as_ref())?;
158-
let table_url = delta_kernel::try_parse_uri(&table_path)?;
159-
let snapshot_after_cp = Snapshot::builder_for(table_url).build(engine.as_ref())?;
160-
let sorted = read_sorted(&snapshot_after_cp, engine)?;
161-
assert_normal_values(&sorted);
99+
// ===== Step 4: Scan and verify values survive checkpoint + reload. =====
100+
verify_and_checkpoint(&snapshot, engine, assert_normal_values)?;
162101

163102
Ok(())
164103
}
@@ -174,47 +113,29 @@ async fn test_write_partitioned_null_values_roundtrip(
174113
#[case] cm_mode: ColumnMappingMode,
175114
) -> Result<(), Box<dyn std::error::Error>> {
176115
// ===== Step 1: Create table and write one row with all-null partition values. =====
177-
let (_tmp_dir, table_path, engine) = test_table_setup_mt()?;
178-
let schema = all_types_schema();
179-
let arrow_schema: Arc<ArrowSchema> = Arc::new(schema.as_ref().try_into_arrow()?);
180-
let snapshot = create_all_types_table(&table_path, engine.as_ref(), cm_mode)?;
181-
182-
let batch = RecordBatch::try_new(arrow_schema, null_arrow_columns())?;
183-
let snapshot = write_batch(&snapshot, engine.as_ref(), batch, null_partition_values()?).await?;
116+
let (_tmp_dir, table_path, snapshot, engine) = setup_and_write(
117+
all_types_schema(),
118+
PARTITION_COLS,
119+
cm_mode,
120+
null_arrow_columns(),
121+
null_partition_values()?,
122+
)
123+
.await?;
184124

185125
// ===== Step 2: Validate add.path structure in the commit log JSON. =====
186-
let adds = read_add_actions_json(&table_path, 1)?;
187-
assert_eq!(adds.len(), 1, "should have exactly one add action");
188-
let add = &adds[0];
189-
let raw_path = add["path"].as_str().unwrap();
190-
191-
assert!(
192-
!raw_path.contains("://"),
193-
"should produce relative paths, got: {raw_path}"
194-
);
195-
196-
let rel_path = strip_table_root(raw_path, snapshot.table_root());
197-
198-
let hdp = "__HIVE_DEFAULT_PARTITION__";
126+
let (add, rel_path) = read_single_add(&table_path, 1)?;
199127
match cm_mode {
200128
ColumnMappingMode::None => {
201129
// Every partition column should use HIVE_DEFAULT_PARTITION in the path.
202-
let expected_prefix = hive_prefix(PARTITION_COLS, hdp);
130+
let expected_prefix = hive_prefix(PARTITION_COLS, "__HIVE_DEFAULT_PARTITION__");
203131
assert!(
204132
rel_path.starts_with(&expected_prefix),
205133
"CM off null: relative path mismatch.\n \
206134
expected: {expected_prefix}<uuid>.parquet\n got: {rel_path}"
207135
);
208136
}
209137
ColumnMappingMode::Name | ColumnMappingMode::Id => {
210-
let segments: Vec<&str> = rel_path.split('/').collect();
211-
assert_eq!(
212-
segments.len(),
213-
2,
214-
"CM on: path should be <prefix>/<file>, got: {rel_path}"
215-
);
216-
assert_eq!(segments[0].len(), 2);
217-
assert!(segments[0].chars().all(|c| c.is_ascii_alphanumeric()));
138+
assert_cm_path(&rel_path);
218139
}
219140
}
220141

@@ -228,16 +149,8 @@ async fn test_write_partitioned_null_values_roundtrip(
228149
);
229150
}
230151

231-
// ===== Step 4: Read data back via scan and verify all partition columns are null. =====
232-
let sorted = read_sorted(&snapshot, engine.clone())?;
233-
assert_all_partition_columns_null(&sorted);
234-
235-
// ===== Step 5: Checkpoint, reload snapshot from checkpoint, read again. =====
236-
snapshot.checkpoint(engine.as_ref())?;
237-
let table_url = delta_kernel::try_parse_uri(&table_path)?;
238-
let snapshot_after_cp = Snapshot::builder_for(table_url).build(engine.as_ref())?;
239-
let sorted = read_sorted(&snapshot_after_cp, engine)?;
240-
assert_all_partition_columns_null(&sorted);
152+
// ===== Step 4: Scan and verify all-null values survive checkpoint + reload. =====
153+
verify_and_checkpoint(&snapshot, engine, assert_all_partition_columns_null)?;
241154

242155
Ok(())
243156
}
@@ -459,6 +372,19 @@ fn assert_all_partition_columns_null(sorted: &RecordBatch) {
459372
}
460373
}
461374

375+
/// Asserts a CM=name/id relative path has the shape `<2char>/<file>.parquet`.
376+
fn assert_cm_path(rel_path: &str) {
377+
let segments: Vec<&str> = rel_path.split('/').collect();
378+
assert_eq!(
379+
segments.len(),
380+
2,
381+
"CM on: path should be <prefix>/<file>, got: {rel_path}"
382+
);
383+
assert_eq!(segments[0].len(), 2, "prefix should be 2 chars");
384+
assert!(segments[0].chars().all(|c| c.is_ascii_alphanumeric()));
385+
assert!(segments[1].ends_with(".parquet"));
386+
}
387+
462388
// ==============================================================================
463389
// Table setup and utility helpers
464390
// ==============================================================================
@@ -471,23 +397,23 @@ fn cm_mode_str(mode: ColumnMappingMode) -> &'static str {
471397
}
472398
}
473399

474-
fn create_all_types_table(
400+
fn create_partitioned_table(
475401
table_path: &str,
476402
engine: &dyn delta_kernel::Engine,
403+
schema: Arc<StructType>,
404+
partition_cols: &[&str],
477405
cm_mode: ColumnMappingMode,
478406
) -> Result<Arc<Snapshot>, Box<dyn std::error::Error>> {
479-
let schema = all_types_schema();
480407
let mut builder = create_table(table_path, schema, "test/1.0")
481-
.with_data_layout(DataLayout::partitioned(PARTITION_COLS.to_vec()));
408+
.with_data_layout(DataLayout::partitioned(partition_cols));
482409
if cm_mode != ColumnMappingMode::None {
483410
builder =
484411
builder.with_table_properties([("delta.columnMapping.mode", cm_mode_str(cm_mode))]);
485412
}
486413
let _ = builder
487414
.build(engine, Box::new(FileSystemCommitter::new()))?
488415
.commit(engine)?;
489-
let table_url = delta_kernel::try_parse_uri(table_path)?;
490-
Ok(Snapshot::builder_for(table_url).build(engine)?)
416+
Ok(Snapshot::builder_for(table_path).build(engine)?)
491417
}
492418

493419
fn read_sorted(
@@ -538,15 +464,67 @@ fn decimal_array(value: i128, precision: u8, scale: i8) -> ArrayRef {
538464
)
539465
}
540466

467+
/// Creates a partitioned table, writes one batch, and returns the post-commit snapshot.
468+
/// Callers can subsequently use [`read_single_add`] with the returned `table_path` to
469+
/// inspect the commit log.
470+
async fn setup_and_write(
471+
schema: Arc<StructType>,
472+
partition_cols: &[&str],
473+
cm_mode: ColumnMappingMode,
474+
arrow_columns: Vec<ArrayRef>,
475+
partition_values: HashMap<String, Scalar>,
476+
) -> Result<
477+
(
478+
tempfile::TempDir,
479+
String,
480+
Arc<Snapshot>,
481+
Arc<dyn delta_kernel::Engine>,
482+
),
483+
Box<dyn std::error::Error>,
484+
> {
485+
let (tmp_dir, table_path, engine) = test_table_setup_mt()?;
486+
let arrow_schema: Arc<ArrowSchema> = Arc::new(schema.as_ref().try_into_arrow()?);
487+
let snapshot = create_partitioned_table(
488+
&table_path,
489+
engine.as_ref(),
490+
schema,
491+
partition_cols,
492+
cm_mode,
493+
)?;
494+
495+
let batch = RecordBatch::try_new(arrow_schema, arrow_columns)?;
496+
let snapshot =
497+
write_batch_to_table(&snapshot, engine.as_ref(), batch, partition_values).await?;
498+
499+
Ok((
500+
tmp_dir,
501+
table_path,
502+
snapshot,
503+
engine as Arc<dyn delta_kernel::Engine>,
504+
))
505+
}
506+
507+
/// Reads the snapshot, runs `assert_fn` on the sorted scan result, then checkpoints,
508+
/// reloads a fresh snapshot from disk, and runs `assert_fn` again on the reloaded scan.
509+
fn verify_and_checkpoint(
510+
snapshot: &Arc<Snapshot>,
511+
engine: Arc<dyn delta_kernel::Engine>,
512+
assert_fn: fn(&RecordBatch),
513+
) -> Result<(), Box<dyn std::error::Error>> {
514+
let sorted = read_sorted(snapshot, engine.clone())?;
515+
assert_fn(&sorted);
516+
517+
snapshot.checkpoint(engine.as_ref())?;
518+
let reloaded = Snapshot::builder_for(snapshot.table_root()).build(engine.as_ref())?;
519+
let sorted = read_sorted(&reloaded, engine)?;
520+
assert_fn(&sorted);
521+
Ok(())
522+
}
523+
541524
// ==============================================================================
542525
// JSON commit log helpers
543526
// ==============================================================================
544527

545-
/// Returns the relative path portion from an add.path, which is always relative.
546-
fn strip_table_root(path: &str, _table_root: &Url) -> String {
547-
path.to_string()
548-
}
549-
550528
/// Builds an unescaped Hive-style path prefix like `col1=val/col2=val/`.
551529
/// Only correct when `value` contains no Hive-special characters.
552530
fn hive_prefix(cols: &[&str], value: &str) -> String {
@@ -572,3 +550,20 @@ fn read_add_actions_json(
572550
.filter_map(|v| v.get("add").cloned())
573551
.collect())
574552
}
553+
554+
/// Reads the single add action from the commit at `version` and returns the action
555+
/// together with its path. Asserts the path is relative (no `scheme://` prefix).
556+
fn read_single_add(
557+
table_path: &str,
558+
version: u64,
559+
) -> Result<(serde_json::Value, String), Box<dyn std::error::Error>> {
560+
let adds = read_add_actions_json(table_path, version)?;
561+
assert_eq!(adds.len(), 1, "should have exactly one add action");
562+
let add = adds.into_iter().next().unwrap();
563+
let rel_path = add["path"].as_str().unwrap().to_string();
564+
assert!(
565+
!rel_path.contains("://"),
566+
"should produce relative paths, got: {rel_path}"
567+
);
568+
Ok((add, rel_path))
569+
}

0 commit comments

Comments (0)