diff --git a/src/combinations.rs b/src/combinations.rs index 61833cecd..a3f4c9550 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -50,6 +50,11 @@ impl Iterator for Combinations { type Item = Vec; fn next(&mut self) -> Option { + if self.first && self.k == 0 { + self.first = false; + return Some(Vec::new()); + } + let mut pool_len = self.pool.len(); if self.pool.is_done() { if pool_len == 0 || self.k > pool_len { diff --git a/src/combinations_with_replacement.rs b/src/combinations_with_replacement.rs index 499ccf70d..a7bd6a581 100644 --- a/src/combinations_with_replacement.rs +++ b/src/combinations_with_replacement.rs @@ -65,13 +65,12 @@ where fn next(&mut self) -> Option { // If this is the first iteration, return early if self.first { - // In empty edge cases, stop iterating immediately - return if self.k == 0 || self.pool.is_done() { - None - // Otherwise, yield the initial state - } else { - self.first = false; - Some(self.current()) + self.first = false; + // Handle corner cases of ((N,0)), ((0,N)), and ((0,0)) + return match (self.pool.is_done(), self.k == 0) { + (_, true) => Some(Vec::new()), // ((0/N, 0)) = 1 + (true, false) => None, // ((0, N)) = 0 + _ => Some(self.current()) // Otherwise, yield the initial state }; } diff --git a/tests/test_std.rs b/tests/test_std.rs index 4f48c55cf..adcdec82c 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -599,6 +599,13 @@ fn combinations_of_too_short() { #[test] fn combinations_zero() { it::assert_equal((1..3).combinations(0), vec![vec![]]); + it::assert_equal((0..0).combinations(0), vec![vec![]]); +} + +#[test] +fn permutations_zero() { + it::assert_equal((1..3).permutations(0), vec![vec![]]); + it::assert_equal((0..0).permutations(0), vec![vec![]]); } #[test] @@ -620,7 +627,12 @@ fn combinations_with_replacement() { // Zero size it::assert_equal( (0..3).combinations_with_replacement(0), - >>::new(), + vec![vec![]], + ); + // Zero size on empty pool + it::assert_equal( + (0..0).combinations_with_replacement(0), + vec![vec![]], ); // Empty pool it::assert_equal(