diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index d84b0e7b8..6238d0ed1 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -3,6 +3,7 @@ use crate::imp_prelude::*; use crate::IntoDimension; use crate::Layout; use crate::NdProducer; +use crate::Slice; /// Window producer and iterable /// @@ -24,16 +25,19 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { let mut unit_stride = D::zeros(ndim); unit_stride.slice_mut().fill(1); - + Windows::new_with_stride(a, window, unit_stride) } - pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self where E: IntoDimension, { let window = window_size.into_dimension(); - let strides_d = strides.into_dimension(); + + let strides = axis_strides.into_dimension(); + let window_strides = a.strides.clone(); + ndassert!( a.ndim() == window.ndim(), concat!( @@ -44,45 +48,35 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { a.ndim(), a.shape() ); + ndassert!( - a.ndim() == strides_d.ndim(), + a.ndim() == strides.ndim(), concat!( "Stride dimension {} does not match array dimension {} ", "(with array of shape {:?})" ), - strides_d.ndim(), + strides.ndim(), a.ndim(), a.shape() ); - let mut size = a.dim; - for ((sz, &ws), &stride) in size - .slice_mut() - .iter_mut() - .zip(window.slice()) - .zip(strides_d.slice()) - { - assert_ne!(ws, 0, "window-size must not be zero!"); - assert_ne!(stride, 0, "stride cannot have a dimension as zero!"); - // cannot use std::cmp::max(0, ..) since arithmetic underflow panics - *sz = if *sz < ws { - 0 - } else { - ((*sz - (ws - 1) - 1) / stride) + 1 - }; - } - let window_strides = a.strides.clone(); - let mut array_strides = a.strides.clone(); - for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) { - *arr_stride *= ix_stride; - } + let mut base = a; + base.slice_each_axis_inplace(|ax_desc| { + let len = ax_desc.len; + let wsz = window[ax_desc.axis.index()]; + let stride = strides[ax_desc.axis.index()]; - unsafe { - Windows { - base: ArrayView::new(a.ptr, size, array_strides), - window, - strides: window_strides, + if len < wsz { + Slice::new(0, Some(0), 1) + } else { + Slice::new(0, Some((len - wsz + 1) as isize), stride as isize) } + }); + + Windows { + base, + window, + strides: window_strides, } } } diff --git a/tests/windows.rs b/tests/windows.rs index 2c928aaef..095976eaa 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -302,3 +302,31 @@ fn test_window_neg_stride() { answer.iter() ); } + +#[test] +fn test_windows_with_stride_on_inverted_axis() { + let mut array = Array::from_iter(1..17).into_shape((4, 4)).unwrap(); + + // inverting axis results in negative stride + array.invert_axis(Axis(0)); + itertools::assert_equal( + array.windows_with_stride((2, 2), (2,2)), + vec![ + arr2(&[[13, 14], [9, 10]]), + arr2(&[[15, 16], [11, 12]]), + arr2(&[[5, 6], [1, 2]]), + arr2(&[[7, 8], [3, 4]]), + ], + ); + + array.invert_axis(Axis(1)); + itertools::assert_equal( + array.windows_with_stride((2, 2), (2,2)), + vec![ + arr2(&[[16, 15], [12, 11]]), + arr2(&[[14, 13], [10, 9]]), + arr2(&[[8, 7], [4, 3]]), + arr2(&[[6, 5], [2, 1]]), + ], + ); +} \ No newline at end of file