From 3986824d1ed73a5c0763d2075410c087464ec74e Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Tue, 21 May 2024 23:13:48 -0400 Subject: [PATCH 1/4] Adds `triu` and `tril` methods that mimic NumPy. Includes branched implementations for f- and c-order arrays. --- src/lib.rs | 3 + src/tri.rs | 288 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 src/tri.rs diff --git a/src/lib.rs b/src/lib.rs index bfeb6835b..d75d65faa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1616,3 +1616,6 @@ pub(crate) fn is_aligned(ptr: *const T) -> bool { (ptr as usize) % ::std::mem::align_of::() == 0 } + +// Triangular constructors +mod tri; diff --git a/src/tri.rs b/src/tri.rs new file mode 100644 index 000000000..94846add4 --- /dev/null +++ b/src/tri.rs @@ -0,0 +1,288 @@ +// Copyright 2014-2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::cmp::{max, min}; + +use num_traits::Zero; + +use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip}; + +impl ArrayBase +where + S: Data, + D: Dimension, + A: Clone + Zero, + D::Smaller: Copy, +{ + /// Upper triangular of an array. + /// + /// Return a copy of the array with elements below the *k*-th diagonal zeroed. + /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes. + /// For 0D and 1D arrays, `triu` will return an unchanged clone. + /// + /// See also [`ArrayBase::tril`] + /// + /// ``` + /// use ndarray::array; + /// + /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + /// let res = arr.triu(0); + /// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + /// ``` + pub fn triu(&self, k: isize) -> Array + { + match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) { + true => { + let n = self.ndim(); + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.tril(-k); + tril.swap_axes(n - 2, n - 1); + + tril + } + false => { + let mut res = Array::zeros(self.raw_dim()); + Zip::indexed(self.rows()) + .and(res.rows_mut()) + .for_each(|i, src, mut dst| { + let row_num = i.into_dimension().last_elem(); + let lower = max(row_num as isize + k, 0); + dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + }); + + res + } + } + } + + /// Lower triangular of an array. + /// + /// Return a copy of the array with elements above the *k*-th diagonal zeroed. + /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes. + /// For 0D and 1D arrays, `tril` will return an unchanged clone. + /// + /// See also [`ArrayBase::triu`] + /// + /// ``` + /// use ndarray::array; + /// + /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + /// let res = arr.tril(0); + /// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + /// ``` + pub fn tril(&self, k: isize) -> Array + { + match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) { + true => { + let n = self.ndim(); + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.triu(-k); + tril.swap_axes(n - 2, n - 1); + + tril + } + false => { + let mut res = Array::zeros(self.raw_dim()); + Zip::indexed(self.rows()) + .and(res.rows_mut()) + .for_each(|i, src, mut dst| { + // This ncols must go inside the loop to avoid panic on 1D arrays. + // Statistically-neglible difference in performance vs defining ncols at top. + let ncols = src.len_of(Axis(src.ndim() - 1)) as isize; + let row_num = i.into_dimension().last_elem(); + let upper = min(row_num as isize + k, ncols) + 1; + dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + }); + + res + } + } + } +} + +#[cfg(test)] +mod tests +{ + use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; + use std::vec; + + #[test] + fn test_keep_order() + { + let x = Array2::::ones((3, 3).f()); + let res = x.triu(0); + assert!(dimension::is_layout_f(&res.dim, &res.strides)); + + let res = x.tril(0); + assert!(dimension::is_layout_f(&res.dim, &res.strides)); + } + + #[test] + fn test_0d() + { + let x = Array0::::ones(()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.tril(0); + assert_eq!(res, x); + + let x = Array0::::ones(().f()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.tril(0); + assert_eq!(res, x); + } + + #[test] + fn test_1d() + { + let x = array![1, 2, 3]; + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.triu(0); + assert_eq!(res, x); + + let x = Array1::::ones(3.f()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.triu(0); + assert_eq!(res, x); + } + + #[test] + fn test_2d() + { + let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + + // Upper + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + + // Lower + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + + let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap(); + + // Upper + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + + // Lower + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + } + + #[test] + fn test_3d() + { + let x = array![ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ]; + + // Upper + let res = x.triu(0); + assert_eq!( + res, + array![ + [[1, 2, 3], [0, 5, 6], [0, 0, 9]], + [[10, 11, 12], [0, 14, 15], [0, 0, 18]], + [[19, 20, 21], [0, 23, 24], [0, 0, 27]] + ] + ); + + // Lower + let res = x.tril(0); + assert_eq!( + res, + array![ + [[1, 0, 0], [4, 5, 0], [7, 8, 9]], + [[10, 0, 0], [13, 14, 0], [16, 17, 18]], + [[19, 0, 0], [22, 23, 0], [25, 26, 27]] + ] + ); + + let x = Array3::from_shape_vec( + (3, 3, 3).f(), + vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27], + ) + .unwrap(); + + // Upper + let res = x.triu(0); + assert_eq!( + res, + array![ + [[1, 2, 3], [0, 5, 6], [0, 0, 9]], + [[10, 11, 12], [0, 14, 15], [0, 0, 18]], + [[19, 20, 21], [0, 23, 24], [0, 0, 27]] + ] + ); + + // Lower + let res = x.tril(0); + assert_eq!( + res, + array![ + [[1, 0, 0], [4, 5, 0], [7, 8, 9]], + [[10, 0, 0], [13, 14, 0], [16, 17, 18]], + [[19, 0, 0], [22, 23, 0], [25, 26, 27]] + ] + ); + } + + #[test] + fn test_off_axis() + { + let x = array![ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ]; + + let res = x.triu(1); + assert_eq!( + res, + array![ + [[0, 2, 3], [0, 0, 6], [0, 0, 0]], + [[0, 11, 12], [0, 0, 15], [0, 0, 0]], + [[0, 20, 21], [0, 0, 24], [0, 0, 0]] + ] + ); + + let res = x.triu(-1); + assert_eq!( + res, + array![ + [[1, 2, 3], [4, 5, 6], [0, 8, 9]], + [[10, 11, 12], [13, 14, 15], [0, 17, 18]], + [[19, 20, 21], [22, 23, 24], [0, 26, 27]] + ] + ); + } + + #[test] + fn test_odd_shape() + { + let x = array![[1, 2, 3], [4, 5, 6]]; + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); + + let x = array![[1, 2], [3, 4], [5, 6]]; + let res = x.triu(0); + assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]); + } +} From 449ad0e08aab60d1a574bf844d9da41ded2d2729 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Tue, 21 May 2024 23:20:05 -0400 Subject: [PATCH 2/4] Uses alloc:: instead of std:: for vec! import --- src/tri.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tri.rs b/src/tri.rs index 94846add4..bccc5ec87 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -111,7 +111,7 @@ where mod tests { use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; - use std::vec; + use alloc::vec; #[test] fn test_keep_order() From 9e9c3f9710c718c68eb382d6ef7032bea91618bb Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Fri, 24 May 2024 23:06:13 -0400 Subject: [PATCH 3/4] Uses initial check to clarify logic --- src/tri.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index bccc5ec87..0d51bc255 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -36,7 +36,10 @@ where /// ``` pub fn triu(&self, k: isize) -> Array { - match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) { + if self.ndim() <= 1 { + return self.to_owned(); + } + match is_layout_f(&self.dim, &self.strides) { true => { let n = self.ndim(); let mut x = self.view(); @@ -78,7 +81,10 @@ where /// ``` pub fn tril(&self, k: isize) -> Array { - match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) { + if self.ndim() <= 1 { + return self.to_owned(); + } + match is_layout_f(&self.dim, &self.strides) { true => { let n = self.ndim(); let mut x = self.view(); @@ -90,12 +96,12 @@ where } false => { let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(self.ndim() - 1)) as isize; Zip::indexed(self.rows()) .and(res.rows_mut()) .for_each(|i, src, mut dst| { // This ncols must go inside the loop to avoid panic on 1D arrays. // Statistically-neglible difference in performance vs defining ncols at top. - let ncols = src.len_of(Axis(src.ndim() - 1)) as isize; let row_num = i.into_dimension().last_elem(); let upper = min(row_num as isize + k, ncols) + 1; dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); From 84f0c80d9d95d47eff6a28f4dd1cc4c4a5cf8171 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Sat, 25 May 2024 15:26:19 -0400 Subject: [PATCH 4/4] Removes unecessary comment --- src/tri.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 0d51bc255..4eab9e105 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -100,8 +100,6 @@ where Zip::indexed(self.rows()) .and(res.rows_mut()) .for_each(|i, src, mut dst| { - // This ncols must go inside the loop to avoid panic on 1D arrays. - // Statistically-neglible difference in performance vs defining ncols at top. let row_num = i.into_dimension().last_elem(); let upper = min(row_num as isize + k, ncols) + 1; dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));