Skip to content

Commit b3def82

Browse files
authored
Introduce new Layout type for clarity (#96)
1 parent 76f9726 commit b3def82

File tree

2 files changed

+43
-20
lines changed

2 files changed

+43
-20
lines changed

src/npy/header.rs

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,26 @@ impl From<FormatHeaderError> for WriteHeaderError {
316316
}
317317
}
318318

319+
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
320+
pub enum Layout {
321+
/// Standard layout (C order).
322+
Standard,
323+
/// Fortran layout.
324+
Fortran,
325+
}
326+
327+
impl Layout {
328+
/// Returns `true` if the layout is [`Fortran`](Self::Fortran).
329+
#[inline]
330+
pub fn is_fortran(&self) -> bool {
331+
matches!(*self, Layout::Fortran)
332+
}
333+
}
334+
319335
#[derive(Clone, Debug)]
320336
pub struct Header {
321337
pub type_descriptor: PyValue,
322-
pub fortran_order: bool,
338+
pub layout: Layout,
323339
pub shape: Vec<usize>,
324340
}
325341

@@ -333,7 +349,7 @@ impl Header {
333349
fn from_py_value(value: PyValue) -> Result<Self, ParseHeaderError> {
334350
if let PyValue::Dict(dict) = value {
335351
let mut type_descriptor: Option<PyValue> = None;
336-
let mut fortran_order: Option<bool> = None;
352+
let mut is_fortran: Option<bool> = None;
337353
let mut shape: Option<Vec<usize>> = None;
338354
for (key, value) in dict {
339355
match key {
@@ -342,7 +358,7 @@ impl Header {
342358
}
343359
PyValue::String(ref k) if k == "fortran_order" => {
344360
if let PyValue::Boolean(b) = value {
345-
fortran_order = Some(b);
361+
is_fortran = Some(b);
346362
} else {
347363
return Err(ParseHeaderError::IllegalValue {
348364
key: "fortran_order".to_owned(),
@@ -370,12 +386,19 @@ impl Header {
370386
k => return Err(ParseHeaderError::UnknownKey(k)),
371387
}
372388
}
373-
match (type_descriptor, fortran_order, shape) {
374-
(Some(type_descriptor), Some(fortran_order), Some(shape)) => Ok(Header {
375-
type_descriptor,
376-
fortran_order,
377-
shape,
378-
}),
389+
match (type_descriptor, is_fortran, shape) {
390+
(Some(type_descriptor), Some(is_fortran), Some(shape)) => {
391+
let layout = if is_fortran {
392+
Layout::Fortran
393+
} else {
394+
Layout::Standard
395+
};
396+
Ok(Header {
397+
type_descriptor,
398+
layout,
399+
shape,
400+
})
401+
}
379402
(None, _, _) => Err(ParseHeaderError::MissingKey("descr".to_owned())),
380403
(_, None, _) => Err(ParseHeaderError::MissingKey("fortran_order".to_owned())),
381404
(_, _, None) => Err(ParseHeaderError::MissingKey("shaper".to_owned())),
@@ -433,7 +456,7 @@ impl Header {
433456
),
434457
(
435458
PyValue::String("fortran_order".into()),
436-
PyValue::Boolean(self.fortran_order),
459+
PyValue::Boolean(self.layout.is_fortran()),
437460
),
438461
(
439462
PyValue::String("shape".into()),

src/npy/mod.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod elements;
22
pub mod header;
33

44
use self::header::{
5-
FormatHeaderError, Header, ParseHeaderError, ReadHeaderError, WriteHeaderError,
5+
FormatHeaderError, Header, Layout, ParseHeaderError, ReadHeaderError, WriteHeaderError,
66
};
77
use ndarray::prelude::*;
88
use ndarray::{Data, DataOwned, IntoDimension};
@@ -170,7 +170,7 @@ where
170170
.expect("overflow converting length of data to u64");
171171
Header {
172172
type_descriptor: A::type_descriptor(),
173-
fortran_order: false,
173+
layout: Layout::Standard,
174174
shape: dim.as_array_view().to_vec(),
175175
}
176176
.write(file)?;
@@ -329,10 +329,10 @@ where
329329
D: Dimension,
330330
{
331331
fn write_npy<W: io::Write>(&self, mut writer: W) -> Result<(), WriteNpyError> {
332-
let write_contiguous = |mut writer: W, fortran_order: bool| {
332+
let write_contiguous = |mut writer: W, layout: Layout| {
333333
Header {
334334
type_descriptor: A::type_descriptor(),
335-
fortran_order,
335+
layout,
336336
shape: self.shape().to_owned(),
337337
}
338338
.write(&mut writer)?;
@@ -341,13 +341,13 @@ where
341341
Ok(())
342342
};
343343
if self.is_standard_layout() {
344-
write_contiguous(writer, false)
344+
write_contiguous(writer, Layout::Standard)
345345
} else if self.view().reversed_axes().is_standard_layout() {
346-
write_contiguous(writer, true)
346+
write_contiguous(writer, Layout::Fortran)
347347
} else {
348348
Header {
349349
type_descriptor: A::type_descriptor(),
350-
fortran_order: false,
350+
layout: Layout::Standard,
351351
shape: self.shape().to_owned(),
352352
}
353353
.write(&mut writer)?;
@@ -577,7 +577,7 @@ where
577577
let ndim = shape.ndim();
578578
let len = shape_length_checked::<A>(&shape).ok_or(ReadNpyError::LengthOverflow)?;
579579
let data = A::read_to_end_exact_vec(&mut reader, &header.type_descriptor, len)?;
580-
ArrayBase::from_shape_vec(shape.set_f(header.fortran_order), data)
580+
ArrayBase::from_shape_vec(shape.set_f(header.layout.is_fortran()), data)
581581
.unwrap()
582582
.into_dimensionality()
583583
.map_err(|_| ReadNpyError::WrongNdim(D::NDIM, ndim))
@@ -821,7 +821,7 @@ where
821821
let ndim = shape.ndim();
822822
let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
823823
let data = A::bytes_as_slice(reader, &header.type_descriptor, len)?;
824-
ArrayView::from_shape(shape.set_f(header.fortran_order), data)
824+
ArrayView::from_shape(shape.set_f(header.layout.is_fortran()), data)
825825
.unwrap()
826826
.into_dimensionality()
827827
.map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))
@@ -841,7 +841,7 @@ where
841841
let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
842842
let mid = buf.len() - reader.len();
843843
let data = A::bytes_as_mut_slice(&mut buf[mid..], &header.type_descriptor, len)?;
844-
ArrayViewMut::from_shape(shape.set_f(header.fortran_order), data)
844+
ArrayViewMut::from_shape(shape.set_f(header.layout.is_fortran()), data)
845845
.unwrap()
846846
.into_dimensionality()
847847
.map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))

0 commit comments

Comments
 (0)