diff --git a/src/permutations.rs b/src/permutations.rs index a8885a411..21a84c401 100644 --- a/src/permutations.rs +++ b/src/permutations.rs @@ -3,6 +3,7 @@ use std::fmt; use std::iter::once; use super::lazy_buffer::LazyBuffer; +use crate::size_hint::{self, SizeHint}; /// An iterator adaptor that iterates through all the `k`-permutations of the /// elements from an iterator. @@ -47,11 +48,6 @@ enum CompleteState { } } -enum CompleteStateRemaining { - Known(usize), - Overflow, -} - impl fmt::Debug for Permutations where I: Iterator + fmt::Debug, I::Item: fmt::Debug, @@ -72,14 +68,8 @@ pub fn permutations(iter: I, k: usize) -> Permutations { }; } - let mut enough_vals = true; - - while vals.len() < k { - if !vals.get_next() { - enough_vals = false; - break; - } - } + vals.prefill(k); + let enough_vals = vals.len() == k; let state = if enough_vals { PermutationState::StartUnknownLen { k } @@ -123,12 +113,7 @@ where fn count(self) -> usize { fn from_complete(complete_state: CompleteState) -> usize { - match complete_state.remaining() { - CompleteStateRemaining::Known(count) => count, - CompleteStateRemaining::Overflow => { - panic!("Iterator count greater than usize::MAX"); - } - } + complete_state.remaining().expect("Iterator count greater than usize::MAX") } let Permutations { vals, state } = self; @@ -151,13 +136,23 @@ where } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> SizeHint { + let at_start = |k| { + // At the beginning, there are `n!/(n-k)!` items to come (see `remaining`) but `n` might be unknown. + let (mut low, mut upp) = self.vals.size_hint(); + low = CompleteState::Start { n: low, k }.remaining().unwrap_or(usize::MAX); + upp = upp.and_then(|n| CompleteState::Start { n, k }.remaining()); + (low, upp) + }; match self.state { - PermutationState::StartUnknownLen { .. } | - PermutationState::OngoingUnknownLen { .. } => (0, None), // TODO can we improve this lower bound? + PermutationState::StartUnknownLen { k } => at_start(k), + PermutationState::OngoingUnknownLen { k, min_n } => { + // Same as `StartUnknownLen` minus the previously generated items. + size_hint::sub_scalar(at_start(k), min_n - k + 1) + } PermutationState::Complete(ref state) => match state.remaining() { - CompleteStateRemaining::Known(count) => (count, Some(count)), - CompleteStateRemaining::Overflow => (::std::usize::MAX, None) + Some(count) => (count, Some(count)), + None => (::std::usize::MAX, None) } PermutationState::Empty => (0, Some(0)) } @@ -238,39 +233,27 @@ impl CompleteState { } } - fn remaining(&self) -> CompleteStateRemaining { - use self::CompleteStateRemaining::{Known, Overflow}; - + /// Returns the count of remaining permutations, or None if it would overflow. + fn remaining(&self) -> Option { match *self { CompleteState::Start { n, k } => { if n < k { - return Known(0); + return Some(0); } - - let count: Option = (n - k + 1..n + 1).fold(Some(1), |acc, i| { + (n - k + 1..=n).fold(Some(1), |acc, i| { acc.and_then(|acc| acc.checked_mul(i)) - }); - - match count { - Some(count) => Known(count), - None => Overflow - } + }) } CompleteState::Ongoing { ref indices, ref cycles } => { let mut count: usize = 0; for (i, &c) in cycles.iter().enumerate() { let radix = indices.len() - i; - let next_count = count.checked_mul(radix) - .and_then(|count| count.checked_add(c)); - - count = match next_count { - Some(count) => count, - None => { return Overflow; } - }; + count = count.checked_mul(radix) + .and_then(|count| count.checked_add(c))?; } - Known(count) + Some(count) } } } diff --git a/src/size_hint.rs b/src/size_hint.rs index f7278aec9..76ccd13a2 100644 --- a/src/size_hint.rs +++ b/src/size_hint.rs @@ -30,7 +30,6 @@ pub fn add_scalar(sh: SizeHint, x: usize) -> SizeHint { /// Subtract `x` correctly from a `SizeHint`. #[inline] -#[allow(dead_code)] pub fn sub_scalar(sh: SizeHint, x: usize) -> SizeHint { let (mut low, mut hi) = sh; low = low.saturating_sub(x); diff --git a/tests/test_std.rs b/tests/test_std.rs index 8ea992183..77d210fb2 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -986,6 +986,44 @@ fn permutations_zero() { it::assert_equal((0..0).permutations(0), vec![vec![]]); } +#[test] +fn permutations_range_count() { + for n in 0..=7 { + for k in 0..=7 { + let len = if k <= n { + (n - k + 1..=n).product() + } else { + 0 + }; + let mut it = (0..n).permutations(k); + assert_eq!(len, it.clone().count()); + assert_eq!(len, it.size_hint().0); + assert_eq!(Some(len), it.size_hint().1); + for count in (0..len).rev() { + let elem = it.next(); + assert!(elem.is_some()); + assert_eq!(count, it.clone().count()); + assert_eq!(count, it.size_hint().0); + assert_eq!(Some(count), it.size_hint().1); + } + let should_be_none = it.next(); + assert!(should_be_none.is_none()); + } + } +} + +#[test] +fn permutations_overflowed_size_hints() { + let mut it = std::iter::repeat(()).permutations(2); + assert_eq!(it.size_hint().0, usize::MAX); + assert_eq!(it.size_hint().1, None); + for nb_generated in 1..=1000 { + it.next(); + assert!(it.size_hint().0 >= usize::MAX - nb_generated); + assert_eq!(it.size_hint().1, None); + } +} + #[test] fn combinations_with_replacement() { // Pool smaller than n