diff --git a/src/combinations_with_replacement.rs b/src/combinations_with_replacement.rs index 0fec9671a..0e3a20e0c 100644 --- a/src/combinations_with_replacement.rs +++ b/src/combinations_with_replacement.rs @@ -3,6 +3,7 @@ use std::fmt; use std::iter::FusedIterator; use super::lazy_buffer::LazyBuffer; +use crate::combinations::checked_binomial; /// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement. /// @@ -24,7 +25,7 @@ where I: Iterator + fmt::Debug, I::Item: fmt::Debug + Clone, { - debug_fmt_fields!(Combinations, indices, pool, first); + debug_fmt_fields!(CombinationsWithReplacement, indices, pool, first); } impl CombinationsWithReplacement @@ -100,6 +101,19 @@ where None => None, } } + + fn size_hint(&self) -> (usize, Option) { + let (mut low, mut upp) = self.pool.size_hint(); + low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX); + upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices)); + (low, upp) + } + + fn count(self) -> usize { + let Self { indices, pool, first } = self; + let n = pool.count(); + remaining_for(n, first, &indices).unwrap() + } } impl FusedIterator for CombinationsWithReplacement @@ -107,3 +121,47 @@ where I: Iterator, I::Item: Clone, {} + +/// For a given size `n`, return the count of remaining combinations with replacement or None if it would overflow. +fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option { + // With a "stars and bars" representation, choose k values with replacement from n values is + // like choosing k out of k + n − 1 positions (hence binomial(k + n - 1, k) possibilities) + // to place k stars and therefore n - 1 bars. + // Example (n=4, k=6): ***|*||** represents [0,0,0,1,3,3]. + let count = |n: usize, k: usize| { + let positions = if n == 0 { k.saturating_sub(1) } else { (n - 1).checked_add(k)? }; + checked_binomial(positions, k) + }; + let k = indices.len(); + if first { + count(n, k) + } else { + // The algorithm is similar to the one for combinations *without replacement*, + // except we choose values *with replacement* and indices are *non-strictly* monotonically sorted. + + // The combinations generated after the current one can be counted by counting as follows: + // - The subsequent combinations that differ in indices[0]: + // If subsequent combinations differ in indices[0], then their value for indices[0] + // must be at least 1 greater than the current indices[0]. + // As indices is monotonically sorted, this means we can effectively choose k values with + // replacement from (n - 1 - indices[0]), leading to count(n - 1 - indices[0], k) possibilities. + // - The subsequent combinations with same indices[0], but differing indices[1]: + // Here we can choose k - 1 values with replacement from (n - 1 - indices[1]) values, + // leading to count(n - 1 - indices[1], k - 1) possibilities. + // - (...) + // - The subsequent combinations with same indices[0..=i], but differing indices[i]: + // Here we can choose k - i values with replacement from (n - 1 - indices[i]) values: count(n - 1 - indices[i], k - i). + // Since subsequent combinations can in any index, we must sum up the aforementioned binomial coefficients. + + // Below, `n0` resembles indices[i]. + indices + .iter() + .enumerate() + // TODO: Once the MSRV hits 1.37.0, we can sum options instead: + // .map(|(i, n0)| count(n - 1 - *n0, k - i)) + // .sum() + .fold(Some(0), |sum, (i, n0)| { + sum.and_then(|s| s.checked_add(count(n - 1 - *n0, k - i)?)) + }) + } +} diff --git a/tests/test_std.rs b/tests/test_std.rs index 8ea992183..94918598a 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -1019,6 +1019,28 @@ fn combinations_with_replacement() { ); } +#[test] +fn combinations_with_replacement_range_count() { + for n in 0..=7 { + for k in 0..=7 { + let len = binomial(usize::saturating_sub(n + k, 1), k); + let mut it = (0..n).combinations_with_replacement(k); + assert_eq!(len, it.clone().count()); + assert_eq!(len, it.size_hint().0); + assert_eq!(Some(len), it.size_hint().1); + for count in (0..len).rev() { + let elem = it.next(); + assert!(elem.is_some()); + assert_eq!(count, it.clone().count()); + assert_eq!(count, it.size_hint().0); + assert_eq!(Some(count), it.size_hint().1); + } + let should_be_none = it.next(); + assert!(should_be_none.is_none()); + } + } +} + #[test] fn powerset() { it::assert_equal((0..0).powerset(), vec![vec![]]);