Skip to content

Commit d728dfb

Browse files
authored
Add matrix operations (rust-lang#326)
1 parent eee8bd0 commit d728dfb

File tree

2 files changed

+190
-0
lines changed

2 files changed

+190
-0
lines changed

src/math/matrix_ops.rs

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// Basic matrix operations using row vectors wrapped in column vectors as matrices.
2+
// Supports i32, should be interchangeable for other types.
3+
// Wikipedia reference: https://www.wikiwand.com/en/Matrix_(mathematics)
4+
5+
pub fn matrix_add(summand0: &[Vec<i32>], summand1: &[Vec<i32>]) -> Vec<Vec<i32>> {
6+
// Add two matrices of identical dimensions
7+
let mut result: Vec<Vec<i32>> = vec![];
8+
if summand0.len() != summand1.len() {
9+
panic!("Matrix dimensions do not match");
10+
}
11+
for row in 0..summand0.len() {
12+
if summand0[row].len() != summand1[row].len() {
13+
panic!("Matrix dimensions do not match");
14+
}
15+
result.push(vec![]);
16+
for column in 0..summand1[0].len() {
17+
result[row].push(summand0[row][column] + summand1[row][column]);
18+
}
19+
}
20+
result
21+
}
22+
23+
pub fn matrix_subtract(minuend: &[Vec<i32>], subtrahend: &[Vec<i32>]) -> Vec<Vec<i32>> {
24+
// Subtract one matrix from another. They need to have identical dimensions.
25+
let mut result: Vec<Vec<i32>> = vec![];
26+
if minuend.len() != subtrahend.len() {
27+
panic!("Matrix dimensions do not match");
28+
}
29+
for row in 0..minuend.len() {
30+
if minuend[row].len() != subtrahend[row].len() {
31+
panic!("Matrix dimensions do not match");
32+
}
33+
result.push(vec![]);
34+
for column in 0..subtrahend[0].len() {
35+
result[row].push(minuend[row][column] - subtrahend[row][column]);
36+
}
37+
}
38+
result
39+
}
40+
41+
// Disable cargo clippy warnings about needless range loops.
42+
// As the iterating variable is used as index while multiplying,
43+
// using the item itself would defeat the variables purpose.
44+
#[allow(clippy::needless_range_loop)]
45+
pub fn matrix_multiply(multiplier: &[Vec<i32>], multiplicand: &[Vec<i32>]) -> Vec<Vec<i32>> {
46+
// Multiply two matching matrices. The multiplier needs to have the same amount
47+
// of columns as the multiplicand has rows.
48+
let mut result: Vec<Vec<i32>> = vec![];
49+
let mut temp;
50+
// Using variable to compare lenghts of rows in multiplicand later
51+
let row_right_length = multiplicand[0].len();
52+
for row_left in 0..multiplier.len() {
53+
if multiplier[row_left].len() != multiplicand.len() {
54+
panic!("Matrix dimensions do not match");
55+
}
56+
result.push(vec![]);
57+
for column_right in 0..multiplicand[0].len() {
58+
temp = 0;
59+
for row_right in 0..multiplicand.len() {
60+
if row_right_length != multiplicand[row_right].len() {
61+
// If row is longer than a previous row cancel operation with error
62+
panic!("Matrix dimensions do not match");
63+
}
64+
temp += multiplier[row_left][row_right] * multiplicand[row_right][column_right];
65+
}
66+
result[row_left].push(temp);
67+
}
68+
}
69+
result
70+
}
71+
72+
pub fn matrix_transpose(matrix: &[Vec<i32>]) -> Vec<Vec<i32>> {
73+
// Transpose a matrix of any size
74+
let mut result: Vec<Vec<i32>> = vec![Vec::with_capacity(matrix.len()); matrix[0].len()];
75+
for row in matrix {
76+
for col in 0..row.len() {
77+
result[col].push(row[col]);
78+
}
79+
}
80+
result
81+
}
82+
83+
pub fn matrix_scalar_multiplication(matrix: &[Vec<i32>], scalar: i32) -> Vec<Vec<i32>> {
84+
// Multiply a matrix of any size with a scalar
85+
let mut result: Vec<Vec<i32>> = vec![Vec::with_capacity(matrix.len()); matrix[0].len()];
86+
for row in 0..matrix.len() {
87+
for column in 0..matrix[row].len() {
88+
result[row].push(scalar * matrix[row][column]);
89+
}
90+
}
91+
result
92+
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
use super::matrix_add;
97+
use super::matrix_multiply;
98+
use super::matrix_scalar_multiplication;
99+
use super::matrix_subtract;
100+
use super::matrix_transpose;
101+
102+
#[test]
103+
fn test_add() {
104+
let input0: Vec<Vec<i32>> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]];
105+
let input1: Vec<Vec<i32>> = vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]];
106+
let input_wrong0: Vec<Vec<i32>> = vec![vec![1, 0, 0, 4], vec![0, 1, 0], vec![0, 0, 1]];
107+
let input_wrong1: Vec<Vec<i32>> =
108+
vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1], vec![1, 1, 1]];
109+
let input_wrong2: Vec<Vec<i32>> = vec![vec![]];
110+
let exp_result: Vec<Vec<i32>> = vec![vec![2, 0, 1], vec![0, 3, 0], vec![5, 0, 2]];
111+
assert_eq!(matrix_add(&input0, &input1), exp_result);
112+
let result0 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong0));
113+
assert!(result0.is_err());
114+
let result1 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong1));
115+
assert!(result1.is_err());
116+
let result2 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong2));
117+
assert!(result2.is_err());
118+
}
119+
120+
#[test]
121+
fn test_subtract() {
122+
let input0: Vec<Vec<i32>> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]];
123+
let input1: Vec<Vec<i32>> = vec![vec![1, 0, 0], vec![0, 1, 3], vec![0, 0, 1]];
124+
let input_wrong0: Vec<Vec<i32>> = vec![vec![1, 0, 0, 4], vec![0, 1, 0], vec![0, 0, 1]];
125+
let input_wrong1: Vec<Vec<i32>> =
126+
vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1], vec![1, 1, 1]];
127+
let input_wrong2: Vec<Vec<i32>> = vec![vec![]];
128+
let exp_result: Vec<Vec<i32>> = vec![vec![0, 0, 1], vec![0, 1, -3], vec![5, 0, 0]];
129+
assert_eq!(matrix_subtract(&input0, &input1), exp_result);
130+
let result0 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong0));
131+
assert!(result0.is_err());
132+
let result1 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong1));
133+
assert!(result1.is_err());
134+
let result2 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong2));
135+
assert!(result2.is_err());
136+
}
137+
138+
#[test]
139+
fn test_multiply() {
140+
let input0: Vec<Vec<i32>> =
141+
vec![vec![1, 2, 3], vec![4, 2, 6], vec![3, 4, 1], vec![2, 4, 8]];
142+
let input1: Vec<Vec<i32>> = vec![vec![1, 3, 3, 2], vec![7, 6, 2, 1], vec![3, 4, 2, 1]];
143+
let input_wrong0: Vec<Vec<i32>> = vec![
144+
vec![1, 3, 3, 2, 4, 6, 6],
145+
vec![7, 6, 2, 1],
146+
vec![3, 4, 2, 1],
147+
];
148+
let input_wrong1: Vec<Vec<i32>> = vec![
149+
vec![1, 3, 3, 2],
150+
vec![7, 6, 2, 1],
151+
vec![3, 4, 2, 1],
152+
vec![3, 4, 2, 1],
153+
];
154+
let exp_result: Vec<Vec<i32>> = vec![
155+
vec![24, 27, 13, 7],
156+
vec![36, 48, 28, 16],
157+
vec![34, 37, 19, 11],
158+
vec![54, 62, 30, 16],
159+
];
160+
assert_eq!(matrix_multiply(&input0, &input1), exp_result);
161+
let result0 = std::panic::catch_unwind(|| matrix_multiply(&input0, &input_wrong0));
162+
assert!(result0.is_err());
163+
let result1 = std::panic::catch_unwind(|| matrix_multiply(&input0, &input_wrong1));
164+
assert!(result1.is_err());
165+
}
166+
167+
#[test]
168+
fn test_transpose() {
169+
let input0: Vec<Vec<i32>> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]];
170+
let input1: Vec<Vec<i32>> = vec![vec![3, 4, 2], vec![0, 1, 3], vec![3, 1, 1]];
171+
let exp_result1: Vec<Vec<i32>> = vec![vec![1, 0, 5], vec![0, 2, 0], vec![1, 0, 1]];
172+
let exp_result2: Vec<Vec<i32>> = vec![vec![3, 0, 3], vec![4, 1, 1], vec![2, 3, 1]];
173+
assert_eq!(matrix_transpose(&input0), exp_result1);
174+
assert_eq!(matrix_transpose(&input1), exp_result2);
175+
}
176+
177+
#[test]
178+
fn test_matrix_scalar_multiplication() {
179+
let input0: Vec<Vec<i32>> = vec![vec![3, 2, 2], vec![0, 2, 0], vec![5, 4, 1]];
180+
let input1: Vec<Vec<i32>> = vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]];
181+
let exp_result1: Vec<Vec<i32>> = vec![vec![9, 6, 6], vec![0, 6, 0], vec![15, 12, 3]];
182+
let exp_result2: Vec<Vec<i32>> = vec![vec![3, 0, 0], vec![0, 3, 0], vec![0, 0, 3]];
183+
assert_eq!(matrix_scalar_multiplication(&input0, 3), exp_result1);
184+
assert_eq!(matrix_scalar_multiplication(&input1, 3), exp_result2);
185+
}
186+
}

src/math/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod greatest_common_divisor;
88
mod karatsuba_multiplication;
99
mod lcm_of_n_numbers;
1010
mod linear_sieve;
11+
mod matrix_ops;
1112
mod miller_rabin;
1213
mod nthprime;
1314
mod pascal_triangle;
@@ -38,6 +39,9 @@ pub use self::greatest_common_divisor::{
3839
pub use self::karatsuba_multiplication::multiply;
3940
pub use self::lcm_of_n_numbers::lcm;
4041
pub use self::linear_sieve::LinearSieve;
42+
pub use self::matrix_ops::{
43+
matrix_add, matrix_multiply, matrix_scalar_multiplication, matrix_subtract, matrix_transpose,
44+
};
4145
pub use self::miller_rabin::miller_rabin;
4246
pub use self::nthprime::nthprime;
4347
pub use self::pascal_triangle::pascal_triangle;

0 commit comments

Comments
 (0)