@@ -63,7 +63,7 @@ def _get_grid_and_space_dimension(filename: str) -> Tuple[int, int]:
6363 return int (dim ), int (space_dim )
6464
6565
66- def _check_vtk_file (vtk_reader ,
66+ def _check_vtk_file (vtk_grid ,
6767 points ,
6868 space_dim ,
6969 reference_function : Callable [[list ], float ],
@@ -72,9 +72,8 @@ def _check_vtk_file(vtk_reader,
7272 rel_tol = 1e-5
7373 abs_tol = 1e-3
7474
75- output = vtk_reader .GetOutput ()
7675 if not skip_metadata :
77- field_data = output .GetFieldData ()
76+ field_data = vtk_grid .GetFieldData ()
7877 expected_field_data = ["literal" , "string" , "numbers" ]
7978 for i in range (field_data .GetNumberOfArrays ()):
8079 name = field_data .GetAbstractArray (i ).GetName ()
@@ -92,12 +91,12 @@ def _check_vtk_file(vtk_reader,
9291 reference_function .set_time (time_value )
9392
9493 # precompute cell centers
95- num_cells = output .GetNumberOfCells ()
94+ num_cells = vtk_grid .GetNumberOfCells ()
9695 points = array (points )
9796 cell_centers = ndarray (shape = (num_cells , 3 ))
9897 for cell_id in range (num_cells ):
9998 ids = vtk .vtkIdList ()
100- output .GetCellPoints (cell_id , ids )
99+ vtk_grid .GetCellPoints (cell_id , ids )
101100 corner_indices = [ids .GetId (_i ) for _i in range (ids .GetNumberOfIds ())]
102101 cell_centers [cell_id ] = np_sum (points [corner_indices ], axis = 0 )
103102 cell_centers [cell_id ] /= float (ids .GetNumberOfIds ())
@@ -115,7 +114,7 @@ def _compare_data_array(arr, position_call_back):
115114 else :
116115 assert isclose (0.0 , value [comp ], rel_tol = rel_tol , abs_tol = abs_tol )
117116
118- point_data = output .GetPointData ()
117+ point_data = vtk_grid .GetPointData ()
119118 for i in range (point_data .GetNumberOfArrays ()):
120119 name = point_data .GetArrayName (i )
121120 arr = point_data .GetArray (i )
@@ -125,15 +124,15 @@ def _compare_data_array(arr, position_call_back):
125124 point_id = 0
126125 for cell_id in range (num_cells ):
127126 ids = vtk .vtkIdList ()
128- output .GetCellPoints (cell_id , ids )
127+ vtk_grid .GetCellPoints (cell_id , ids )
129128 for _ in range (ids .GetNumberOfIds ()):
130129 assert isclose (arr .GetTuple (point_id )[0 ], float (cell_id ))
131130 point_id += 1
132131 else :
133132 for i in range (arr .GetNumberOfTuples ()):
134133 _compare_data_array (arr , lambda i : points [i ])
135134
136- cell_data = output .GetCellData ()
135+ cell_data = vtk_grid .GetCellData ()
137136 for i in range (cell_data .GetNumberOfArrays ()):
138137 name = cell_data .GetArrayName (i )
139138 arr = cell_data .GetArray (i )
@@ -154,16 +153,42 @@ def _read_pvd_pieces(filename: str) -> List[_TimeStep]:
154153
155154
156155def _test_vtk (filename : str , skip_metadata : bool , reference_function : Callable [[list ], float ]):
156+ def _grid (reader ):
157+ return reader .GetOutput ()
158+
159+ # For readers that potentially return partitioned datasets
160+ def _merged_partitioned_grid (reader ):
161+ output = reader .GetOutput ()
162+ if isinstance (output , vtk .vtkUnstructuredGrid ):
163+ return output
164+ merged = vtk .vtkAppendFilter ()
165+ iter = output .NewIterator ()
166+ iter .InitTraversal ()
167+ while not iter .IsDoneWithTraversal ():
168+ dataset = iter .GetCurrentDataObject ()
169+ if dataset :
170+ merged .AddInputData (dataset )
171+ iter .GoToNextItem ()
172+ merged .Update ()
173+ merged_grid = merged .GetOutput ()
174+ merged_grid .GetFieldData ().PassData (reader .GetOutput ().GetFieldData ())
175+ return merged_grid
176+
157177 def _get_points_from_grid (reader ):
158178 points = reader .GetOutput ().GetPoints ()
159179 return array ([points .GetPoint (i ) for i in range (points .GetNumberOfPoints ())])
160180
181+ def _get_points_from_partioned_grid (reader ):
182+ points = _merged_partitioned_grid (reader ).GetPoints ()
183+ return array ([points .GetPoint (i ) for i in range (points .GetNumberOfPoints ())])
184+
161185 def _get_rectilinear_points (reader ):
162186 output = reader .GetOutput ()
163187 return array ([output .GetPoint (i ) for i in range (output .GetNumberOfPoints ())])
164188
165189 e = VTKErrorObserver ()
166190 ext = splitext (filename )[1 ]
191+ get_grid = _grid
167192 if ext == ".vtu" :
168193 reader = vtk .vtkXMLUnstructuredGridReader ()
169194 point_collector = _get_points_from_grid
@@ -199,7 +224,8 @@ def _get_rectilinear_points(reader):
199224 point_collector = _get_rectilinear_points
200225 elif ext == ".hdf" and "unstructured" in filename :
201226 reader = vtk .vtkHDFReader ()
202- point_collector = _get_points_from_grid
227+ point_collector = _get_points_from_partioned_grid
228+ get_grid = _merged_partitioned_grid
203229 else :
204230 raise NotImplementedError (f"Could not determine suitable reader { filename } " )
205231 reader .AddObserver ("ErrorEvent" , e )
@@ -210,7 +236,7 @@ def _get_rectilinear_points(reader):
210236
211237 _ , space_dim = _get_grid_and_space_dimension (filename )
212238 is_discontinuous = "discontinuous" in filename
213- _check_vtk_file (reader , point_collector (reader ), space_dim , reference_function , skip_metadata , is_discontinuous )
239+ _check_vtk_file (get_grid ( reader ) , point_collector (reader ), space_dim , reference_function , skip_metadata , is_discontinuous )
214240
215241
216242def test (filename : str , skip_metadata : bool = False ) -> int | None :
0 commit comments