From b97707cdc2fe2e90641463b1f68076d29b9a7b91 Mon Sep 17 00:00:00 2001 From: SabrinaJewson Date: Sun, 22 Sep 2024 14:04:25 +0100 Subject: [PATCH 1/2] feat: support sampling integers from discrete distributions --- src/distribution/bernoulli.rs | 10 ++++++- src/distribution/binomial.rs | 16 +++++++--- src/distribution/categorical.rs | 26 +++++++++++----- src/distribution/discrete_uniform.rs | 10 ++++++- src/distribution/geometric.rs | 23 ++++++++++----- src/distribution/hypergeometric.rs | 16 +++++++--- src/distribution/multinomial.rs | 44 +++++++++++++++++++++------- src/distribution/poisson.rs | 14 +++++++++ 8 files changed, 125 insertions(+), 34 deletions(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index 521fa7ba..d059fb32 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -84,11 +84,19 @@ impl std::fmt::Display for Bernoulli { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Bernoulli { + fn sample(&self, rng: &mut R) -> bool { + rng.gen_bool(self.p()) + } +} + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for Bernoulli { fn sample(&self, rng: &mut R) -> f64 { - rng.gen_bool(self.p()) as u8 as f64 + f64::from(rng.sample::(self)) } } diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index dca3754c..1c5e2be2 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -111,12 +111,12 @@ impl std::fmt::Display for Binomial { #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Binomial { - fn sample(&self, rng: &mut R) -> f64 { - (0..self.n).fold(0.0, |acc, _| { +impl ::rand::distributions::Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + (0..self.n).fold(0, |acc, _| { let n: f64 = rng.gen(); if n < self.p { - acc + 1.0 + acc + 1 } else { acc } @@ -124,6 +124,14 @@ impl ::rand::distributions::Distribution for Binomial { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Binomial { + fn sample(&self, rng: &mut R) -> f64 { + rng.sample::(self) as f64 + } +} + impl DiscreteCDF for Binomial { /// Calculates the cumulative distribution function for the /// binomial distribution at `x` diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index b63b1312..008e89f9 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -123,11 +123,27 @@ impl std::fmt::Display for Categorical { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Categorical { + fn sample(&self, rng: &mut R) -> usize { + sample_unchecked(rng, &self.cdf) + } +} + +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Categorical { + fn sample(&self, rng: &mut R) -> u64 { + sample_unchecked(rng, &self.cdf) as u64 + } +} + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for Categorical { fn sample(&self, rng: &mut R) -> f64 { - sample_unchecked(rng, &self.cdf) + sample_unchecked(rng, &self.cdf) as f64 } } @@ -325,13 +341,9 @@ impl Discrete for Categorical { /// without doing any bounds checking #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { +pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> usize { let draw = rng.gen::() * cdf.last().unwrap(); - cdf.iter() - .enumerate() - .find(|(_, val)| **val >= draw) - .map(|(i, _)| i) - .unwrap() as f64 + cdf.iter().position(|val| *val >= draw).unwrap() } /// Computes the cdf from the given probability masses. Performs diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index 55ec1b2e..a9760808 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -74,11 +74,19 @@ impl std::fmt::Display for DiscreteUniform { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for DiscreteUniform { + fn sample(&self, rng: &mut R) -> i64 { + rng.gen_range(self.min..=self.max) + } +} + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for DiscreteUniform { fn sample(&self, rng: &mut R) -> f64 { - rng.gen_range(self.min..=self.max) as f64 + rng.sample::(self) as f64 } } diff --git a/src/distribution/geometric.rs b/src/distribution/geometric.rs index 81e7439d..4ec1b30b 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -92,19 +92,28 @@ impl std::fmt::Display for Geometric { #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Geometric { - fn sample(&self, r: &mut R) -> f64 { - use ::rand::distributions::OpenClosed01; - +impl ::rand::distributions::Distribution for Geometric { + fn sample(&self, r: &mut R) -> u64 { if ulps_eq!(self.p, 1.0) { - 1.0 + 1 } else { - let x: f64 = r.sample(OpenClosed01); - x.log(1.0 - self.p).ceil() + let x: f64 = r.sample(::rand::distributions::OpenClosed01); + // This cast is safe, because the largest finite value this expression can take is when + // `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get + // `930262250532780300`, which when casted to a `u64` is `930262250532780288`. + x.log(1.0 - self.p).ceil() as u64 } } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Geometric { + fn sample(&self, r: &mut R) -> f64 { + r.sample::(self) as f64 + } +} + impl DiscreteCDF for Geometric { /// Calculates the cumulative distribution function for the geometric /// distribution at `x` diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 089f351b..a1aef30d 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -147,17 +147,17 @@ impl std::fmt::Display for Hypergeometric { #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Hypergeometric { - fn sample(&self, rng: &mut R) -> f64 { +impl ::rand::distributions::Distribution for Hypergeometric { + fn sample(&self, rng: &mut R) -> u64 { let mut population = self.population as f64; let mut successes = self.successes as f64; let mut draws = self.draws; - let mut x = 0.0; + let mut x = 0; loop { let p = successes / population; let next: f64 = rng.gen(); if next < p { - x += 1.0; + x += 1; successes -= 1.0; } population -= 1.0; @@ -170,6 +170,14 @@ impl ::rand::distributions::Distribution for Hypergeometric { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Hypergeometric { + fn sample(&self, rng: &mut R) -> f64 { + rng.sample::(self) as f64 + } +} + impl DiscreteCDF for Hypergeometric { /// Calculates the cumulative distribution function for the hypergeometric /// distribution at `x` diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index dd36f704..0794d352 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -159,6 +159,19 @@ where } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { + sample_generic(self, rng) + } +} + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution> for Multinomial @@ -167,17 +180,28 @@ where nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn sample(&self, rng: &mut R) -> OVector { - use nalgebra::Const; - - let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); - let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); - for _ in 0..self.n { - let i = super::categorical::sample_unchecked(rng, &p_cdf); - let el = res.get_mut(i as usize).unwrap(); - *el += 1.0; - } - res + sample_generic(self, rng) + } +} + +#[cfg(feature = "rand")] +fn sample_generic(dist: &Multinomial, rng: &mut R) -> OVector +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + R: ::rand::Rng + ?Sized, + T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign, +{ + use nalgebra::Const; + + let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice()); + let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>); + for _ in 0..dist.n { + let i = super::categorical::sample_unchecked(rng, &p_cdf); + res[i] += T::one(); } + res } impl MeanN> for Multinomial diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 45958176..1c0598c4 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -89,6 +89,19 @@ impl std::fmt::Display for Poisson { } } +#[cfg(feature = "rand")] +#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] +impl ::rand::distributions::Distribution for Poisson { + /// Generates one sample from the Poisson distribution either by + /// Knuth's method if lambda < 30.0 or Rejection method PA by + /// A. C. Atkinson from the Journal of the Royal Statistical Society + /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 + /// otherwise + fn sample(&self, rng: &mut R) -> u64 { + sample_unchecked(rng, self.lambda) as u64 + } +} + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for Poisson { @@ -279,6 +292,7 @@ impl Discrete for Poisson { -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } + /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society From 3deb714f585316f61e0fc2ee92bac2292dcbb330 Mon Sep 17 00:00:00 2001 From: SabrinaJewson Date: Sun, 22 Sep 2024 17:46:40 +0100 Subject: [PATCH 2/2] fix: MSRV --- src/distribution/bernoulli.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index d059fb32..21e3e27f 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -96,7 +96,7 @@ impl ::rand::distributions::Distribution for Bernoulli { #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for Bernoulli { fn sample(&self, rng: &mut R) -> f64 { - f64::from(rng.sample::(self)) + rng.sample::(self) as u8 as f64 } }