diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 27c5687ee..ae82a482a 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -10,7 +10,7 @@ use num_traits::Float; use num_traits::One; use num_traits::{FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, MulAssign, Sub}; use crate::imp_prelude::*; use crate::numeric_util; @@ -97,6 +97,45 @@ where D: Dimension sum } + /// Return the cumulative product of elements along a given axis. + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); + /// + /// // Cumulative product along rows (axis 0) + /// assert_eq!( + /// a.cumprod(Axis(0)), + /// arr2(&[[1., 2., 3.], + /// [4., 10., 18.]]) + /// ); + /// + /// // Cumulative product along columns (axis 1) + /// assert_eq!( + /// a.cumprod(Axis(1)), + /// arr2(&[[1., 2., 6.], + /// [4., 20., 120.]]) + /// ); + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn cumprod(&self, axis: Axis) -> Array + where + A: Clone + Mul + MulAssign, + D: Dimension + RemoveAxis, + { + if axis.0 >= self.ndim() { + panic!("axis is out of bounds for array of dimension"); + } + + let mut result = self.to_owned(); + result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone()); + result + } + /// Return variance of elements in the array. /// /// The variance is computed using the [Welford one-pass diff --git a/tests/numeric.rs b/tests/numeric.rs index 839aba58e..7e6964812 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -75,6 +75,76 @@ fn sum_mean_prod_empty() assert_eq!(a, None); } +#[test] +fn test_cumprod_1d() +{ + let a = array![1, 2, 3, 4]; + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![1, 2, 6, 24]); +} + +#[test] +fn test_cumprod_2d() +{ + let a = array![[1, 2], [3, 4]]; + + let result_axis0 = a.cumprod(Axis(0)); + assert_eq!(result_axis0, array![[1, 2], [3, 8]]); + + let result_axis1 = a.cumprod(Axis(1)); + assert_eq!(result_axis1, array![[1, 2], [3, 12]]); +} + +#[test] +fn test_cumprod_3d() +{ + let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; + + let result_axis0 = a.cumprod(Axis(0)); + assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]); + + let result_axis1 = a.cumprod(Axis(1)); + assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]); + + let result_axis2 = a.cumprod(Axis(2)); + assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]); +} + +#[test] +fn test_cumprod_empty() +{ + // For 2D empty array + let b: Array2 = Array2::zeros((0, 0)); + let result_axis0 = b.cumprod(Axis(0)); + assert_eq!(result_axis0, Array2::zeros((0, 0))); + let result_axis1 = b.cumprod(Axis(1)); + assert_eq!(result_axis1, Array2::zeros((0, 0))); +} + +#[test] +fn test_cumprod_1_element() +{ + // For 1D array with one element + let a = array![5]; + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![5]); + + // For 2D array with one element + let b = array![[5]]; + let result_axis0 = b.cumprod(Axis(0)); + let result_axis1 = b.cumprod(Axis(1)); + assert_eq!(result_axis0, array![[5]]); + assert_eq!(result_axis1, array![[5]]); +} + +#[test] +#[should_panic(expected = "axis is out of bounds for array of dimension")] +fn test_cumprod_axis_out_of_bounds() +{ + let a = array![[1, 2], [3, 4]]; + let _result = a.cumprod(Axis(2)); +} + #[test] #[cfg(feature = "std")] fn var()