11extern crate approx;
22use std:: f64;
3- use ndarray:: { Array1 , array } ;
3+ use ndarray:: { array , Axis , aview1 , aview2 , aview0 , arr0 , arr1 , arr2 , Array , Array1 , Array2 , Array3 } ;
44use approx:: abs_diff_eq;
55
66#[ test]
@@ -32,4 +32,172 @@ fn test_mean_with_array_of_floats() {
3232 // Computed using NumPy
3333 let expected_mean = 0.5475494059146699 ;
3434 abs_diff_eq ! ( a. mean( ) . unwrap( ) , expected_mean, epsilon = f64 :: EPSILON ) ;
35- }
35+ }
36+
37+ #[ test]
38+ fn sum_mean ( )
39+ {
40+ let a = arr2 ( & [ [ 1. , 2. ] , [ 3. , 4. ] ] ) ;
41+ assert_eq ! ( a. sum_axis( Axis ( 0 ) ) , arr1( & [ 4. , 6. ] ) ) ;
42+ assert_eq ! ( a. sum_axis( Axis ( 1 ) ) , arr1( & [ 3. , 7. ] ) ) ;
43+ assert_eq ! ( a. mean_axis( Axis ( 0 ) ) , Some ( arr1( & [ 2. , 3. ] ) ) ) ;
44+ assert_eq ! ( a. mean_axis( Axis ( 1 ) ) , Some ( arr1( & [ 1.5 , 3.5 ] ) ) ) ;
45+ assert_eq ! ( a. sum_axis( Axis ( 1 ) ) . sum_axis( Axis ( 0 ) ) , arr0( 10. ) ) ;
46+ assert_eq ! ( a. view( ) . mean_axis( Axis ( 1 ) ) . unwrap( ) , aview1( & [ 1.5 , 3.5 ] ) ) ;
47+ assert_eq ! ( a. sum( ) , 10. ) ;
48+ }
49+
50+ #[ test]
51+ fn sum_mean_empty ( ) {
52+ assert_eq ! ( Array3 :: <f32 >:: ones( ( 2 , 0 , 3 ) ) . sum( ) , 0. ) ;
53+ assert_eq ! ( Array1 :: <f32 >:: ones( 0 ) . sum_axis( Axis ( 0 ) ) , arr0( 0. ) ) ;
54+ assert_eq ! (
55+ Array3 :: <f32 >:: ones( ( 2 , 0 , 3 ) ) . sum_axis( Axis ( 1 ) ) ,
56+ Array :: zeros( ( 2 , 3 ) ) ,
57+ ) ;
58+ let a = Array1 :: < f32 > :: ones ( 0 ) . mean_axis ( Axis ( 0 ) ) ;
59+ assert_eq ! ( a, None ) ;
60+ let a = Array3 :: < f32 > :: ones ( ( 2 , 0 , 3 ) ) . mean_axis ( Axis ( 1 ) ) ;
61+ assert_eq ! ( a, None ) ;
62+ }
63+
64+ #[ test]
65+ fn var_axis ( ) {
66+ let a = array ! [
67+ [
68+ [ -9.76 , -0.38 , 1.59 , 6.23 ] ,
69+ [ -8.57 , -9.27 , 5.76 , 6.01 ] ,
70+ [ -9.54 , 5.09 , 3.21 , 6.56 ] ,
71+ ] ,
72+ [
73+ [ 8.23 , -9.63 , 3.76 , -3.48 ] ,
74+ [ -5.46 , 5.86 , -2.81 , 1.35 ] ,
75+ [ -1.08 , 4.66 , 8.34 , -0.73 ] ,
76+ ] ,
77+ ] ;
78+ assert ! ( a. var_axis( Axis ( 0 ) , 1.5 ) . all_close(
79+ & aview2( & [
80+ [ 3.236401e+02 , 8.556250e+01 , 4.708900e+00 , 9.428410e+01 ] ,
81+ [ 9.672100e+00 , 2.289169e+02 , 7.344490e+01 , 2.171560e+01 ] ,
82+ [ 7.157160e+01 , 1.849000e-01 , 2.631690e+01 , 5.314410e+01 ]
83+ ] ) ,
84+ 1e-4 ,
85+ ) ) ;
86+ assert ! ( a. var_axis( Axis ( 1 ) , 1.7 ) . all_close(
87+ & aview2( & [
88+ [ 0.61676923 , 80.81092308 , 6.79892308 , 0.11789744 ] ,
89+ [ 75.19912821 , 114.25235897 , 48.32405128 , 9.03020513 ] ,
90+ ] ) ,
91+ 1e-8 ,
92+ ) ) ;
93+ assert ! ( a. var_axis( Axis ( 2 ) , 2.3 ) . all_close(
94+ & aview2( & [
95+ [ 79.64552941 , 129.09663235 , 95.98929412 ] ,
96+ [ 109.64952941 , 43.28758824 , 36.27439706 ] ,
97+ ] ) ,
98+ 1e-8 ,
99+ ) ) ;
100+
101+ let b = array ! [ [ 1.1 , 2.3 , 4.7 ] ] ;
102+ assert ! ( b. var_axis( Axis ( 0 ) , 0. ) . all_close( & aview1( & [ 0. , 0. , 0. ] ) , 1e-12 ) ) ;
103+ assert ! ( b. var_axis( Axis ( 1 ) , 0. ) . all_close( & aview1( & [ 2.24 ] ) , 1e-12 ) ) ;
104+
105+ let c = array ! [ [ ] , [ ] ] ;
106+ assert_eq ! ( c. var_axis( Axis ( 0 ) , 0. ) , aview1( & [ ] ) ) ;
107+
108+ let d = array ! [ 1.1 , 2.7 , 3.5 , 4.9 ] ;
109+ assert ! ( d. var_axis( Axis ( 0 ) , 0. ) . all_close( & aview0( & 1.8875 ) , 1e-12 ) ) ;
110+ }
111+
112+ #[ test]
113+ fn std_axis ( ) {
114+ let a = array ! [
115+ [
116+ [ 0.22935481 , 0.08030619 , 0.60827517 , 0.73684379 ] ,
117+ [ 0.90339851 , 0.82859436 , 0.64020362 , 0.2774583 ] ,
118+ [ 0.44485313 , 0.63316367 , 0.11005111 , 0.08656246 ]
119+ ] ,
120+ [
121+ [ 0.28924665 , 0.44082454 , 0.59837736 , 0.41014531 ] ,
122+ [ 0.08382316 , 0.43259439 , 0.1428889 , 0.44830176 ] ,
123+ [ 0.51529756 , 0.70111616 , 0.20799415 , 0.91851457 ]
124+ ] ,
125+ ] ;
126+ assert ! ( a. std_axis( Axis ( 0 ) , 1.5 ) . all_close(
127+ & aview2( & [
128+ [ 0.05989184 , 0.36051836 , 0.00989781 , 0.32669847 ] ,
129+ [ 0.81957535 , 0.39599997 , 0.49731472 , 0.17084346 ] ,
130+ [ 0.07044443 , 0.06795249 , 0.09794304 , 0.83195211 ] ,
131+ ] ) ,
132+ 1e-4 ,
133+ ) ) ;
134+ assert ! ( a. std_axis( Axis ( 1 ) , 1.7 ) . all_close(
135+ & aview2( & [
136+ [ 0.42698655 , 0.48139215 , 0.36874991 , 0.41458724 ] ,
137+ [ 0.26769097 , 0.18941435 , 0.30555015 , 0.35118674 ] ,
138+ ] ) ,
139+ 1e-8 ,
140+ ) ) ;
141+ assert ! ( a. std_axis( Axis ( 2 ) , 2.3 ) . all_close(
142+ & aview2( & [
143+ [ 0.41117907 , 0.37130425 , 0.35332388 ] ,
144+ [ 0.16905862 , 0.25304841 , 0.39978276 ] ,
145+ ] ) ,
146+ 1e-8 ,
147+ ) ) ;
148+
149+ let b = array ! [ [ 100000. , 1. , 0.01 ] ] ;
150+ assert ! ( b. std_axis( Axis ( 0 ) , 0. ) . all_close( & aview1( & [ 0. , 0. , 0. ] ) , 1e-12 ) ) ;
151+ assert ! (
152+ b. std_axis( Axis ( 1 ) , 0. ) . all_close( & aview1( & [ 47140.214021552769 ] ) , 1e-6 ) ,
153+ ) ;
154+
155+ let c = array ! [ [ ] , [ ] ] ;
156+ assert_eq ! ( c. std_axis( Axis ( 0 ) , 0. ) , aview1( & [ ] ) ) ;
157+ }
158+
159+ #[ test]
160+ #[ should_panic]
161+ fn var_axis_negative_ddof ( ) {
162+ let a = array ! [ 1. , 2. , 3. ] ;
163+ a. var_axis ( Axis ( 0 ) , -1. ) ;
164+ }
165+
166+ #[ test]
167+ #[ should_panic]
168+ fn var_axis_too_large_ddof ( ) {
169+ let a = array ! [ 1. , 2. , 3. ] ;
170+ a. var_axis ( Axis ( 0 ) , 4. ) ;
171+ }
172+
173+ #[ test]
174+ fn var_axis_nan_ddof ( ) {
175+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 3 ) ) ;
176+ let v = a. var_axis ( Axis ( 1 ) , :: std:: f64:: NAN ) ;
177+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
178+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
179+ }
180+
181+ #[ test]
182+ fn var_axis_empty_axis ( ) {
183+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 0 ) ) ;
184+ let v = a. var_axis ( Axis ( 1 ) , 0. ) ;
185+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
186+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
187+ }
188+
189+ #[ test]
190+ #[ should_panic]
191+ fn std_axis_bad_dof ( ) {
192+ let a = array ! [ 1. , 2. , 3. ] ;
193+ a. std_axis ( Axis ( 0 ) , 4. ) ;
194+ }
195+
196+ #[ test]
197+ fn std_axis_empty_axis ( ) {
198+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 0 ) ) ;
199+ let v = a. std_axis ( Axis ( 1 ) , 0. ) ;
200+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
201+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
202+ }
203+
0 commit comments