This question arose when I looked into efficient implementation of the MRG32k3a PRNG on 32-bit processors that have no support, or only slow support, for double computation. I am particularly interested in ARM, RISC-V, and GPUs. MRG32k3a is a very high-quality PRNG and is therefore still in widespread use today, although it dates to the late 1990s:
P. L'Ecuyer, "Good parameters and implementations for combined multiple recursive random number generators." Operations Research, Vol. 47, No. 1, Jan.-Feb. 1999, pp. 159-164
MRG32k3a combines two recursive sequences of the form (c0 ⋅ state0 - c1 ⋅ state1) mod m, where state_i < m. The constants and state variables in MRG32k3a are all non-negative integers that fit into 32 bits, and all intermediate expressions in the computation are less than 2^53 in magnitude. This is by design, as the reference implementation uses IEEE-754 double for storage and computation. The mathematical modulus mod differs from ISO C's % operator in that it always delivers a non-negative result. The first variant in the code below shows a slightly modernized version of the reference implementation using double.
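A small sketch of that difference, using a hypothetical helper mod_math (not part of the code below) that turns C's truncating remainder into the mathematical modulus. It assumes C99 or later, where % truncates toward zero:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: mathematical x mod m, always in [0, m).
   Since C99, x % m truncates toward zero, so the result has the sign of x. */
static int64_t mod_math (int64_t x, int64_t m)
{
    assert (m > 0);
    int64_t r = x % m;              /* -m < r < m, same sign as x */
    return (r < 0) ? (r + m) : r;   /* shift negative remainders into [0, m) */
}
```

For example, -7 % 5 evaluates to -2 in C, whereas -7 mod 5 is 3.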
In an integer-only implementation of MRG32k3a, 32-bit variables are used for the state components, and intermediate computations are performed in 64-bit arithmetic. The modulo is easily computed via % by ensuring that the dividend is non-negative: (c0 ⋅ state0 - c1 ⋅ state1) mod m = (c0 ⋅ state0 - c1 ⋅ state1 + c1 ⋅ m) % m. The computation % m is expensive on 32-bit processors, usually resulting in a library call. This is easily addressed via standard division-by-constant optimizations, with a 64-bit multiply-high computation as the most expensive part (see the GENERIC_MOD=1 variant in the code below).
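The offset identity can be checked against signed arithmetic with an explicit sign fix-up. Both helper names are hypothetical, and the sketch assumes constants of the size used in MRG32k3a (below 2^21), so that every product and sum fits comfortably in 64 bits:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: (c0*s0 - c1*s1) mod m via the offset trick.
   Adding c1*m leaves the residue unchanged, and because s1 < m the true
   dividend is non-negative, so C's % yields the mathematical modulus.
   The unsigned intermediate may wrap; the final sum is exact. */
static uint32_t mod_via_offset (uint32_t c0, uint32_t s0,
                                uint32_t c1, uint32_t s1, uint32_t m)
{
    assert (s1 < m);
    uint64_t x = (uint64_t)c0 * s0 - (uint64_t)c1 * s1 + (uint64_t)c1 * m;
    return (uint32_t)(x % m);
}

/* Hypothetical reference: signed 64-bit arithmetic plus a sign fix-up. */
static uint32_t mod_via_signed (uint32_t c0, uint32_t s0,
                                uint32_t c1, uint32_t s1, uint32_t m)
{
    int64_t x = (int64_t)c0 * s0 - (int64_t)c1 * s1;
    int64_t r = x % (int64_t)m;               /* same sign as x */
    return (uint32_t)((r < 0) ? (r + (int64_t)m) : r);
}
```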
Even faster modulo computation is possible when the magnitude of the dividend is restricted and m = 2^n - d, with small d. One sets lo = x % 2^n, hi = x / 2^n, t = hi * d + lo. Because 2^n ≡ d (mod m), t is congruent to x modulo m, so as long as t < 2 ⋅ m, x mod m = (t >= m) ? (t - m) : t. The desired condition holds when x < 2^(n + (n - ceil(log2(d + 1)))). This works well for the first recurrence used by MRG32k3a, with n = 32 and d = 209, which requires x < 2^56, a condition that is trivially satisfied. For the second recurrence, however, n = 32 and d = 22853, which requires x < 2^49. After applying the offset c1 ⋅ m to ensure a positive x, x can be as large as 8.15 ⋅ 10^15 in this case, only very slightly less than 2^53 ≈ 9 ⋅ 10^15.
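The single-fold reduction and its bound can be sketched as follows, assuming n = 32 as in MRG32k3a. The helper names are hypothetical, and t is formed in 64 bits for clarity rather than with the 32-bit wrap-around check used in the code below:

```c
#include <assert.h>
#include <stdint.h>

/* ceil(log2(d + 1)), i.e. the smallest k with 2^k >= d + 1 */
static unsigned ceil_log2_dp1 (uint32_t d)
{
    unsigned k = 0;
    while (((uint64_t)1 << k) < (uint64_t)d + 1) k++;
    return k;
}

/* Exponent of the bound 2^(n + (n - ceil(log2(d + 1)))) for n = 32 */
static unsigned fold_bound_bits (uint32_t d)
{
    return 32 + (32 - ceil_log2_dp1 (d));
}

/* x mod m for m = 2^32 - d: one fold plus one conditional subtraction */
static uint32_t fold_mod (uint64_t x, uint32_t d)
{
    assert (d > 0 && x < ((uint64_t)1 << fold_bound_bits (d)));
    uint64_t m  = ((uint64_t)1 << 32) - d;
    uint64_t lo = (uint32_t)x;           /* x % 2^32 */
    uint64_t hi = x >> 32;               /* x / 2^32 */
    uint64_t t  = hi * d + lo;           /* t = x (mod m), and t < 2m here */
    return (uint32_t)((t >= m) ? (t - m) : t);
}
```

With d = 209 this gives the 2^56 limit and with d = 22853 the 2^49 limit quoted above.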
I am currently addressing this by basing the offset that is added to ensure a positive x prior to computing x % m on the values of the state variables, which keeps x < 2^44. But as one can see from the relevant code lines extracted below, this is a fairly expensive approach that includes a 32-bit division (with a constant divisor and thus optimizable, but still incurring undesirable cost).
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
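The 2^44 bound claimed for the state-dependent offset in these lines can be spot-checked with a sketch like the following. The helper names and the LCG used for the sweep are illustrative assumptions, and a sampled sweep is evidence rather than a proof:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper reproducing the question's state-dependent offset:
   returns x = a21*s22 - a23n*s20 + adj*m2 for states s20, s22 < m2. */
static uint64_t adj_offset_prod (uint32_t s20, uint32_t s22)
{
    const uint32_t m2 = 4294944443u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    assert (s20 < m2 && s22 < m2);
    uint64_t prod = (uint64_t)a21 * s22 - (uint64_t)a23n * s20; /* may wrap */
    /* 3133 ~ m2 / a23n and 8192 ~ m2 / a21, so adj * m2 roughly cancels
       the two products; the + 1 covers the shortfall from truncating
       s20 / 3133 */
    uint32_t adj = s20 / 3133 - (s22 >> 13) + 1;
    return (uint64_t)((int64_t)(int32_t)adj * (int64_t)m2) + prod;
}

/* Checks x < 2^44 at corner states and along a deterministic sweep driven
   by an arbitrary 32-bit LCG. A negative true x would wrap to a value near
   2^64, so the same comparison also catches a failure of non-negativity. */
static int adj_offset_bound_holds (void)
{
    const uint32_t m2 = 4294944443u;
    const uint64_t limit = (uint64_t)1 << 44;
    const uint32_t corners[] = { 0u, 1u, 3132u, 3133u, 8191u, 8192u,
                                 4294944442u };
    const int nc = (int)(sizeof corners / sizeof corners[0]);
    for (int i = 0; i < nc; i++)
        for (int j = 0; j < nc; j++)
            if (adj_offset_prod (corners[i], corners[j]) >= limit) return 0;
    uint32_t r = 12345u;
    for (int i = 0; i < 1000000; i++) {
        r = r * 1664525u + 1013904223u;
        uint32_t s20 = r % m2;
        r = r * 1664525u + 1013904223u;
        uint32_t s22 = r % m2;
        if (adj_offset_prod (s20, s22) >= limit) return 0;
    }
    return 1;
}
```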
Are there alternative and less costly mitigation strategies that lead to a more efficient modulo computation for the second recurrence used by MRG32k3a?
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>
#define BUILTIN_64BIT (0)
#define GENERIC_MOD (0) // applies only when BUILTIN_64BIT == 0
static double MRG32k3a_s10, MRG32k3a_s11, MRG32k3a_s12;
static double MRG32k3a_s20, MRG32k3a_s21, MRG32k3a_s22;
/* SIMD vectorized by Clang with -ffp-model=precise on x86-64 and AArch64
SIMD vectorized by Intel compiler with -fp-model=precise -march=core-avx2
*/
double MRG32k3a (void)
{
const double norm = 2.328306549295728e-10;
const double m1 = 4294967087.0;
const double m2 = 4294944443.0;
const double a12 = 1403580.0;
const double a13n = 810728.0;
const double a21 = 527612.0;
const double a23n = 1370589.0;
double k, p1, p2;
/* Component 1 */
p1 = a12 * MRG32k3a_s11 - a13n * MRG32k3a_s10;
k = floor (p1 / m1);
p1 -= k * m1;
MRG32k3a_s10 = MRG32k3a_s11; MRG32k3a_s11 = MRG32k3a_s12; MRG32k3a_s12 = p1;
/* Component 2 */
p2 = a21 * MRG32k3a_s22 - a23n * MRG32k3a_s20;
k = floor (p2 / m2);
p2 -= k * m2;
MRG32k3a_s20 = MRG32k3a_s21; MRG32k3a_s21 = MRG32k3a_s22; MRG32k3a_s22 = p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
static uint32_t MRG32k3a_s10i, MRG32k3a_s11i, MRG32k3a_s12i;
static uint32_t MRG32k3a_s20i, MRG32k3a_s21i, MRG32k3a_s22i;
#if BUILTIN_64BIT
double MRG32k3a_i (void)
{
const double norm = 2.328306549295728e-10;
const uint32_t m1 = 4294967087u;
const uint32_t m2 = 4294944443u;
const uint32_t a12 = 1403580u;
const uint32_t a13n = 810728u;
const uint32_t a21 = 527612u;
const uint32_t a23n = 1370589u;
uint64_t prod;
uint32_t p1, p2;
/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
p1 = (uint32_t)(prod % m1);
MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
/* Component 2 */
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
p2 = (uint32_t)(prod % m2);
MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#elif GENERIC_MOD
uint64_t umul64hi (uint64_t a, uint64_t b)
{
uint32_t alo = (uint32_t)a;
uint32_t ahi = (uint32_t)(a >> 32);
uint32_t blo = (uint32_t)b;
uint32_t bhi = (uint32_t)(b >> 32);
uint64_t p0 = (uint64_t)alo * blo;
uint64_t p1 = (uint64_t)alo * bhi;
uint64_t p2 = (uint64_t)ahi * blo;
uint64_t p3 = (uint64_t)ahi * bhi;
return (p1 >> 32) + (((p0 >> 32) + (uint64_t)(uint32_t)p1 + p2) >> 32) + p3;
}
double MRG32k3a_i (void)
{
const double norm = 2.328306549295728e-10;
const uint32_t m1 = 4294967087u;
const uint32_t m2 = 4294944443u;
const uint32_t a12 = 1403580u;
const uint32_t a13n = 810728u;
const uint32_t a21 = 527612u;
const uint32_t a23n = 1370589u;
const uint32_t neg_m1 = 0 - m1; // 209
const uint32_t neg_m2 = 0 - m2; // 22853
const uint64_t magic_mul_m1 = 0x8000006880005551ull;
const uint64_t magic_mul_m2 = 0x4000165147c845ddull;
const uint32_t shft_m1 = 31;
const uint32_t shft_m2 = 30;
uint64_t prod;
uint32_t p1, p2;
/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
p1 = (uint32_t)((umul64hi (prod, magic_mul_m1) >> shft_m1) * neg_m1 + prod);
MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
/* Component 2 */
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
p2 = (uint32_t)((umul64hi (prod, magic_mul_m2) >> shft_m2) * neg_m2 + prod);
MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#else // !BUILTIN_64BIT && !GENERIC_MOD --> special fast modulo computation
double MRG32k3a_i (void)
{
const double norm = 2.328306549295728e-10;
const uint32_t m1 = 4294967087u;
const uint32_t m2 = 4294944443u;
const uint32_t a12 = 1403580u;
const uint32_t a13n = 810728u;
const uint32_t a21 = 527612u;
const uint32_t a23n = 1370589u;
const uint32_t neg_m1 = 0 - m1; // 209
const uint32_t neg_m2 = 0 - m2; // 22853
uint64_t prod;
uint32_t p1, p2, prod_lo, prod_hi, adj;
/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
// ! special modulo computation: prod must be < 2**56 !
prod_lo = (uint32_t)prod;
prod_hi = (uint32_t)(prod >> 32);
p1 = prod_hi * neg_m1 + prod_lo;
if ((p1 >= m1) || (p1 < prod_lo)) p1 += neg_m1;
MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
/* Component 2 */
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
// ! special modulo computation: prod must be < 2**49 !
prod_lo = (uint32_t)prod;
prod_hi = (uint32_t)(prod >> 32);
p2 = prod_hi * neg_m2 + prod_lo;
if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#endif // BUILTIN_64BIT
/*
http://www.burtleburtle.net/bob/hash/doobs.html
By Bob Jenkins, 1996. bob_jenkins@burtleburtle.net. You may use this
code any way you wish, private, educational, or commercial. It's free.
*/
#define mix(a,b,c) \
(a -= b, a -= c, a ^= (c>>13), \
b -= c, b -= a, b ^= (a<<8), \
c -= a, c -= b, c ^= (b>>13), \
a -= b, a -= c, a ^= (c>>12), \
b -= c, b -= a, b ^= (a<<16), \
c -= a, c -= b, c ^= (b>>5), \
a -= b, a -= c, a ^= (c>>3), \
b -= c, b -= a, b ^= (a<<10), \
c -= a, c -= b, c ^= (b>>15))
int main (void)
{
uint32_t m1 = 4294967087u;
uint32_t m2 = 4294944443u;
uint32_t a, b, c;
a = 3141592654u;
b = 2718281828u;
c = 10; MRG32k3a_s10 = MRG32k3a_s10i = (1u << 10) | (mix (a, b, c) % m1);
c = 11; MRG32k3a_s11 = MRG32k3a_s11i = (1u << 11) | (mix (a, b, c) % m1);
c = 12; MRG32k3a_s12 = MRG32k3a_s12i = (1u << 12) | (mix (a, b, c) % m1);
c = 20; MRG32k3a_s20 = MRG32k3a_s20i = (1u << 20) | (mix (a, b, c) % m2);
c = 21; MRG32k3a_s21 = MRG32k3a_s21i = (1u << 21) | (mix (a, b, c) % m2);
c = 22; MRG32k3a_s22 = MRG32k3a_s22i = (1u << 22) | (mix (a, b, c) % m2);
double res, ref;
uint64_t count = 0;
do {
res = MRG32k3a_i();
ref = MRG32k3a();
if (res != ref) {
printf("\ncount=%llu ref=%23.16e res=%23.16e\n", (unsigned long long)count, ref, res);
return EXIT_FAILURE;
}
count++;
if ((count & 0xfffffff) == 0) printf ("\rcount = %llu ", (unsigned long long)count);
} while (ref != 0);
return EXIT_SUCCESS;
}
I haven't tested or benchmarked this, but if I've understood what you're doing correctly, I think another option is to add a fixed multiple of m2 and have two rounds of reduction.
prod = ((uint64_t)a21) * MRG32k3a_s22i + ((uint64_t)m2 << 22) - ((uint64_t)a23n) * MRG32k3a_s20i; // 55 bits
Then you can omit the following two lines.
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
and then do
prod_lo = (uint32_t)prod;
prod_hi = (uint32_t)(prod >> 32); // 23 bits
prod = (uint64_t)prod_hi * neg_m2 + prod_lo; // 39 bits
prod_lo = (uint32_t)prod;
prod_hi = (uint32_t)(prod >> 32);
p2 = prod_hi * neg_m2 + prod_lo;
if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
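A self-contained sketch of the component-2 reduction suggested above, which can be compared against a plain % m2. The function name is hypothetical, and it takes s20 and s22 as arguments instead of reading the question's globals:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: (a21*s22 - a23n*s20) mod m2 using a fixed offset of
   m2 << 22 and two rounds of reduction. */
static uint32_t mrg_component2 (uint32_t s20, uint32_t s22)
{
    const uint32_t m2 = 4294944443u;
    const uint32_t neg_m2 = 0u - m2;                     /* 22853 */
    const uint32_t a21 = 527612u, a23n = 1370589u;
    uint32_t prod_lo, prod_hi, p2;
    uint64_t prod;
    assert (s20 < m2 && s22 < m2);

    /* a23n < 2^21, so m2 << 22 exceeds a23n * s20: result >= 0 and < 2^55 */
    prod = (uint64_t)a21 * s22 + ((uint64_t)m2 << 22) - (uint64_t)a23n * s20;
    /* Round 1: fold the high part using 2^32 = 22853 (mod m2); now < 2^39 */
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    prod = (uint64_t)prod_hi * neg_m2 + prod_lo;
    /* Round 2: prod < 2^49, so one fold plus one conditional subtraction */
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p2 = prod_hi * neg_m2 + prod_lo;                     /* may wrap mod 2^32 */
    if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;      /* subtract m2 */
    return p2;
}
```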
You might also want to change the Component 1 calculation to something like
/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i + ((uint64_t)a13n) * (m1 - MRG32k3a_s10i); // 54 bits
instead of the two steps
/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
so as not to rely on the compiler being clever enough to realise that only two multiplications are needed. I think this is valid because 0 <= MRG32k3a_s1xi < m1, right?
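To check both suggestions together, here is a sketch of a complete integer-only step built from them, paired with a double reference for comparison. The state structs and function names are hypothetical, the state layout differs from the question's globals, and the reference uses truncation plus a sign fix-up instead of floor (an assumption made only to avoid a libm dependency):

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical state types: s[0..2] hold component 1, s[3..5] component 2 */
typedef struct { uint32_t s[6]; } mrg_int_state;
typedef struct { double   s[6]; } mrg_dbl_state;

/* Integer-only step using both suggestions from this answer */
static double mrg32k3a_int (mrg_int_state *st)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u, m2 = 4294944443u;
    const uint32_t neg_m1 = 0u - m1, neg_m2 = 0u - m2;   /* 209, 22853 */
    const uint32_t a12 = 1403580u, a13n = 810728u;
    const uint32_t a21 = 527612u, a23n = 1370589u;
    uint32_t *s = st->s, lo, hi, p1, p2;
    uint64_t prod;
    assert (s[0] < m1 && s[1] < m1 && s[3] < m2 && s[5] < m2);

    /* Component 1: dividend < 2^54 < 2^56, one fold plus a fix-up suffices */
    prod = (uint64_t)a12 * s[1] + (uint64_t)a13n * (m1 - s[0]);
    lo = (uint32_t)prod;
    hi = (uint32_t)(prod >> 32);
    p1 = hi * neg_m1 + lo;                      /* may wrap mod 2^32 */
    if ((p1 >= m1) || (p1 < lo)) p1 += neg_m1;  /* subtract m1 */
    s[0] = s[1]; s[1] = s[2]; s[2] = p1;

    /* Component 2: dividend < 2^55; a first fold brings it below 2^39 */
    prod = (uint64_t)a21 * s[5] + ((uint64_t)m2 << 22) - (uint64_t)a23n * s[3];
    prod = (prod >> 32) * neg_m2 + (uint32_t)prod;
    lo = (uint32_t)prod;
    hi = (uint32_t)(prod >> 32);
    p2 = hi * neg_m2 + lo;
    if ((p2 >= m2) || (p2 < lo)) p2 += neg_m2;  /* subtract m2 */
    s[3] = s[4]; s[4] = s[5]; s[5] = p2;

    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}

/* Double reference: truncating quotient, then a sign fix-up */
static double mrg32k3a_dbl (mrg_dbl_state *st)
{
    const double norm = 2.328306549295728e-10;
    const double m1 = 4294967087.0, m2 = 4294944443.0;
    double *s = st->s, p1, p2;

    p1 = 1403580.0 * s[1] - 810728.0 * s[0];
    p1 -= (double)(int64_t)(p1 / m1) * m1;
    if (p1 < 0.0) p1 += m1;
    s[0] = s[1]; s[1] = s[2]; s[2] = p1;

    p2 = 527612.0 * s[5] - 1370589.0 * s[3];
    p2 -= (double)(int64_t)(p2 / m2) * m2;
    if (p2 < 0.0) p2 += m2;
    s[3] = s[4]; s[4] = s[5]; s[5] = p2;

    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
```

Agreement over a finite number of steps is evidence rather than proof, in the same spirit as the question's own comparison loop.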