Skip to content

Commit a77f984

Browse files
authored
Merge pull request #613 from OpenVicProject/weighted_sampling
Add weighted sampling
2 parents ac4fbed + 5de711f commit a77f984

3 files changed

Lines changed: 308 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <limits>
5+
6+
#include "openvic-simulation/types/fixed_point/FixedPoint.hpp"
7+
8+
namespace OpenVic {
9+
size_t sample_weighted_index(
10+
const uint32_t random_value,
11+
const std::span<fixed_point_t> positive_weights,
12+
const fixed_point_t weights_sum
13+
) {
14+
if (positive_weights.empty() || weights_sum <= 0) {
15+
return -1;
16+
}
17+
18+
constexpr uint32_t max = std::numeric_limits<uint32_t>().max();
19+
const fixed_point_t scaled_random = weights_sum * random_value;
20+
21+
fixed_point_t cumulative_weight = 0;
22+
for (size_t i = 0; i < positive_weights.size(); ++i) {
23+
cumulative_weight += positive_weights[i];
24+
25+
if (cumulative_weight * max >= scaled_random) {
26+
return static_cast<size_t>(i);
27+
}
28+
}
29+
30+
return static_cast<size_t>(positive_weights.size() - 1);
31+
}
32+
}

tests/src/utility/ExtendedMath.hpp

Lines changed: 215 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,215 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <limits>
5+
6+
namespace OpenVic::testing {
	/** @brief Structure to hold a 96-bit signed integer product.
	 * The product of a signed 64-bit number and an unsigned 32-bit number
	 * requires up to 96 bits (2^63 * 2^32 = 2^95).
	 *
	 * The result is stored as a magnitude across two 64-bit words, plus a sign flag.
	 * The mathematical value is: (-1)^is_negative * (high * 2^64 + low)
	 */
	struct Int96Product {
		uint64_t low; ///< The lower 64 bits of the product's magnitude.
		uint64_t high; ///< The upper 32 bits (stored in a 64-bit word) of the product's magnitude.
		bool is_negative; ///< True if the overall product is negative.
	};

	/** @brief Multiplies an int64_t by a uint32_t, returning the exact 96-bit result.
	 * The multiplication is carried out on magnitudes using only 64-bit unsigned
	 * arithmetic, by splitting |a| into two 32-bit halves so that no partial
	 * product can overflow.
	 *
	 * @param a The signed 64-bit integer factor (int64_t).
	 * @param b The unsigned 32-bit integer factor (uint32_t).
	 * @return An Int96Product structure containing the full 96-bit signed result.
	 *         A zero factor yields {0, 0, false}.
	 */
	[[nodiscard]] inline Int96Product portable_int64_mult_uint32_96bit(int64_t a, uint32_t b) {
		Int96Product result = {0, 0, false};

		// A zero factor gives a (non-negative) zero product.
		if (a == 0 || b == 0) {
			return result;
		}

		result.is_negative = (a < 0);

		// Magnitude of 'a'. Negating in unsigned arithmetic is well defined for every
		// input, including INT64_MIN (whose magnitude 2^63 does not fit in int64_t).
		const uint64_t a_mag = (a < 0) ? (uint64_t { 0 } - static_cast<uint64_t>(a)) : static_cast<uint64_t>(a);
		const uint64_t b_mag = b;

		constexpr uint64_t MASK_32 = 0xFFFFFFFFULL;

		// a_mag = a_high * 2^32 + a_low
		const uint64_t a_low = a_mag & MASK_32;
		const uint64_t a_high = a_mag >> 32;

		// Each partial product is 32x32 -> at most 64 bits.
		const uint64_t product_low = a_low * b_mag; // coefficient of 2^0
		const uint64_t product_high = a_high * b_mag; // coefficient of 2^32

		// Total coefficient of 2^32: product_high plus the upper half of product_low.
		// Cannot overflow: (2^32-1)^2 + (2^32-1) < 2^64.
		const uint64_t middle_word = product_high + (product_low >> 32);

		// product = middle_word * 2^32 + (product_low & MASK_32).
		// The low 64-bit word therefore takes the low half of middle_word in its upper
		// 32 bits; using product_low alone would drop product_high's contribution.
		result.low = (middle_word << 32) | (product_low & MASK_32);
		result.high = middle_word >> 32;

		return result;
	}

	/**
	 * @brief Structure to hold the result of the division operation.
	 */
	struct Int96DivisionResult {
		int64_t quotient; ///< The quotient (clamped to INT64_MAX/MIN if overflow occurs).
		int64_t remainder; ///< The remainder, with the same sign as the dividend.
		bool quotient_overflow; ///< True if the mathematical quotient exceeds int64_t limits.
	};

	/**
	 * @brief Divides a 96-bit signed integer (Int96Product) by an int64_t, checking for quotient overflow.
	 *
	 * This is implemented using a portable bit-by-bit long division on magnitudes;
	 * the quotient is truncated toward zero.
	 *
	 * Only the lowest 32 bits of dividend.high are examined (bits 64..95 of the magnitude).
	 *
	 * @param dividend The 96-bit signed number.
	 * @param divisor The signed 64-bit divisor.
	 * @return An Int96DivisionResult containing the quotient, remainder, and overflow status.
	 *         Division by zero reports overflow with quotient INT64_MAX and remainder 0.
	 */
	[[nodiscard]] inline Int96DivisionResult portable_int96_div_int64(const Int96Product& dividend, int64_t divisor) {
		Int96DivisionResult result = {0, 0, false};

		// Division by zero: report overflow with a clamped quotient.
		if (divisor == 0) {
			result.quotient_overflow = true;
			result.quotient = std::numeric_limits<int64_t>::max();
			result.remainder = 0;
			return result;
		}

		// Zero dividend: quotient and remainder are both zero.
		if (dividend.low == 0 && dividend.high == 0) {
			return result;
		}

		const bool result_negative = (dividend.is_negative != (divisor < 0));

		// Magnitude of the divisor; unsigned negation is well defined even for INT64_MIN.
		const uint64_t d_mag =
			(divisor < 0) ? (uint64_t { 0 } - static_cast<uint64_t>(divisor)) : static_cast<uint64_t>(divisor);

		uint64_t quotient_mag = 0;
		// Always < d_mag <= 2^63 at the top of each iteration, so the shift below cannot overflow.
		uint64_t remainder_mag = 0;

		// Walk the dividend magnitude from bit 95 down to bit 0.
		for (int i = 95; i >= 0; --i) {
			const uint64_t next_bit = (i >= 64) ? ((dividend.high >> (i - 64)) & 1) : ((dividend.low >> i) & 1);
			remainder_mag = (remainder_mag << 1) | next_bit;

			if (remainder_mag >= d_mag) {
				remainder_mag -= d_mag;

				if (i >= 64) {
					// A quotient bit at index >= 64 cannot be represented in int64_t.
					// Keep looping so the final remainder is still correct.
					result.quotient_overflow = true;
				} else {
					quotient_mag |= (1ULL << i);
				}
			}
		}

		constexpr uint64_t INT64_MIN_MAG = 1ULL << 63; // |INT64_MIN| = 2^63
		constexpr uint64_t INT64_MAX_MAG = INT64_MIN_MAG - 1;

		if (result_negative) {
			if (result.quotient_overflow || quotient_mag > INT64_MIN_MAG) {
				result.quotient_overflow = true;
				result.quotient = std::numeric_limits<int64_t>::min();
			} else if (quotient_mag == INT64_MIN_MAG) {
				// Exactly INT64_MIN: representable, but cannot be produced by negating an int64_t.
				result.quotient = std::numeric_limits<int64_t>::min();
			} else {
				result.quotient = -static_cast<int64_t>(quotient_mag);
			}
		} else {
			if (result.quotient_overflow || quotient_mag > INT64_MAX_MAG) {
				result.quotient_overflow = true;
				result.quotient = std::numeric_limits<int64_t>::max();
			} else {
				result.quotient = static_cast<int64_t>(quotient_mag);
			}
		}

		// The remainder takes the sign of the dividend. remainder_mag < d_mag <= 2^63,
		// so it always fits in int64_t.
		if (dividend.is_negative && remainder_mag != 0) {
			result.remainder = -static_cast<int64_t>(remainder_mag);
		} else {
			result.remainder = static_cast<int64_t>(remainder_mag);
		}

		return result;
	}
}
Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,61 @@
1+
#include "openvic-simulation/types/fixed_point/FixedPoint.hpp"
2+
#include "openvic-simulation/utility/WeightedSampling.hpp"
3+
4+
#include "utility/ExtendedMath.hpp"
5+
6+
#include <cstdint>
7+
#include <limits>
8+
9+
#include <snitch/snitch_macros_check.hpp>
10+
#include <snitch/snitch_macros_test_case.hpp>
11+
12+
using namespace OpenVic;
13+
using namespace OpenVic::testing;
14+
15+
// Largest value the 32-bit random input of sample_weighted_index can take.
constexpr uint32_t max_random_value = std::numeric_limits<uint32_t>::max();
16+
17+
TEST_CASE("WeightedSampling limits", "[WeightedSampling]") {
	// 4294967295 (UINT32_MAX) is an exact multiple of 257, so the random range
	// splits into 257 equally sized steps with no remainder.
	constexpr size_t bucket_count = 257;
	constexpr uint32_t step_size = max_random_value / bucket_count;

	// Use the largest supported total weight, spread evenly over all buckets.
	const fixed_point_t weights_sum = fixed_point_t::usable_max;
	std::array<fixed_point_t, bucket_count> weights {};
	weights.fill(weights_sum / weights.size());

	// The random value at the start of step `index` is expected to select bucket `index`.
	for (size_t index = 0; index < weights.size(); ++index) {
		const uint32_t random_value = static_cast<uint32_t>(index * step_size);
		CHECK(sample_weighted_index(random_value, weights, weights_sum) == index);
	}
}
31+
TEST_CASE("WeightedSampling weights", "[WeightedSampling]") {
	// Weights are 1, 2, ..., max_length, so weights_sum = (size^2 + size) / 2 = 32640;
	// this needs to be <= fixed_point_t::usable_max.
	constexpr size_t max_length = 255;
	std::array<fixed_point_t, max_length> weights {};

	fixed_point_t weights_sum = 0;
	for (size_t i = 0; i < weights.size(); ++i) {
		const fixed_point_t weight = 1+i;
		weights_sum += weight;
		weights[i] = weight;
	}

	// Precondition of the test itself. REQUIRE (rather than assert) keeps the check
	// active in NDEBUG builds and reports a failure through the test framework.
	// The extra parentheses pass the condition as a single bool expression.
	REQUIRE((weights_sum > 0 && weights_sum <= fixed_point_t::usable_max));

	fixed_point_t cumulative_weight = 0;
	for (size_t i = 0; i < weights.size(); ++i) {
		cumulative_weight += weights[i];

		// Exact reference: floor(cumulative_weight * UINT32_MAX / weights_sum), computed
		// in 96-bit arithmetic. This is the largest random value whose scaled threshold
		// is still covered by the cumulative weight of entry i.
		const Int96DivisionResult random_value = portable_int96_div_int64(
			portable_int64_mult_uint32_96bit(
				cumulative_weight.get_raw_value(),
				max_random_value
			),
			weights_sum.get_raw_value()
		);
		REQUIRE_FALSE(random_value.quotient_overflow);

		CHECK(sample_weighted_index(
			static_cast<uint32_t>(random_value.quotient),
			weights,
			weights_sum
		) == i);
	}
}

0 commit comments

Comments
 (0)