c, algorithm, micro-optimization, integer-division

Optimized 53->32 bit modulo computation on 32-bit processors


This question arose when I looked into efficient implementations of the MRG32k3a PRNG on 32-bit processors with no, or only slow, support for double-precision computation. I am particularly interested in ARM, RISC-V, and GPUs. MRG32k3a is a very high-quality PRNG and is therefore still in widespread use today, although it dates to the late 1990s:

P. L'Ecuyer, "Good parameters and implementations for combined multiple recursive random number generators." Operations Research, Vol. 47, No. 1, Jan.-Feb. 1999, pp. 159-164

MRG32k3a combines two recursive sequences of the form (c0 ⋅ state0 - c1 ⋅ state1) mod m, where statei < m. The constants and state variables in MRG32k3a are all positive integers that fit into 32 bits, and all intermediate expressions in the computation are less than 2^53 in magnitude. This is by design, as the reference implementation uses IEEE-754 double for storage and computation. The mathematical modulus mod differs from ISO-C's % operator by always delivering a non-negative result. The first variant in the code below shows the slightly modernized reference implementation using double.
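For illustration only (this snippet is not part of the reference implementation), the difference between the two matters for negative dividends:

    /* illustration: C's % truncates toward zero, mod never goes negative */
    int r_c   = (-7) % 5;            /* -2 under C99/C11 semantics  */
    int r_mod = ((-7) % 5 + 5) % 5;  /*  3, the mathematical result */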

In an integer-only implementation of MRG32k3a, 32-bit variables are used for the state components, and intermediate computations are performed in 64-bit arithmetic. The modulo is easily computed via % by ensuring that the dividend is non-negative: (c0 ⋅ state0 - c1 ⋅ state1) mod m = (c0 ⋅ state0 - c1 ⋅ state1 + c1 ⋅ m) % m (adding c1 ⋅ m does not change the residue, and it guarantees a non-negative dividend because c1 ⋅ state1 < c1 ⋅ m). The computation of % m is expensive on 32-bit processors, usually resulting in a library call. This is easily addressed via standard division-by-constant optimizations, with a 64-bit multiply-high computation as the most expensive part (see the GENERIC_MOD=1 variant in the code below).

Even faster modulo computation is possible when the magnitude of the dividend is restricted and m = 2^n - d, with small d. One sets lo = x % 2^n, hi = x / 2^n, t = hi ⋅ d + lo. As long as t < 2 ⋅ m, x mod m = (t >= m) ? (t - m) : t. The desired condition holds when x < 2^(n + (n - ceil(log2(d+1)))). This works well for the first recurrence used by MRG32k3a, with n = 32 and d = 209, which requires x < 2^56 and is trivially satisfied. But for the second recurrence, n = 32 and d = 22853, requiring x < 2^49. After applying the offset c1 ⋅ m to ensure a positive x, x can be as large as 8.15 ⋅ 10^15 in this case, only very slightly less than 2^53 ≈ 9 ⋅ 10^15.
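For concreteness, this reduction can be written as a stand-alone helper (a sketch only, mirroring the BUILTIN_64BIT==0, GENERIC_MOD==0 variant further below; the function name is made up):

    /* Sketch: compute x mod m for m = 2^32 - d with small d. Requires
       x < 2^(32 + (32 - ceil(log2(d+1)))), which ensures hi*d fits into
       32 bits and t = hi*d + lo < 2*m; a wrap of the 32-bit addition is
       then caught by the (t < lo) test.                                 */
    static uint32_t mod_pow2_minus_d (uint64_t x, uint32_t d)
    {
        uint32_t m  = 0u - d;              /* m = 2^32 - d            */
        uint32_t lo = (uint32_t)x;         /* x % 2^32                */
        uint32_t hi = (uint32_t)(x >> 32); /* x / 2^32                */
        uint32_t t  = hi * d + lo;         /* t == x (mod m), t < 2*m */
        if ((t >= m) || (t < lo)) t -= m;  /* fold once into [0, m)   */
        return t;
    }

For m1 this amounts to mod_pow2_minus_d (prod, 209), with the 2^56 bound trivially met; for m2 the required bound of 2^49 is exactly what the question is about.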

I am currently addressing this by making the offset that is added to ensure a positive x (prior to computing x % m) dependent on the values of the state variables, which keeps x < 2^44. But as one can see from the relevant code lines extracted below, this is a fairly expensive approach that includes a 32-bit division (with constant divisor and thus optimizable, but still incurring undesirable cost).

    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive

Are there alternative and less costly mitigation strategies that lead to a more efficient modulo computation for the second recurrence used by MRG32k3a?

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

#define BUILTIN_64BIT (0)
#define GENERIC_MOD   (0)  // applies only when BUILTIN_64BIT == 0

static double MRG32k3a_s10, MRG32k3a_s11, MRG32k3a_s12;
static double MRG32k3a_s20, MRG32k3a_s21, MRG32k3a_s22;

/* SIMD vectorized by Clang with -ffp-model=precise on x86-64 and AArch64
   SIMD vectorized by Intel compiler with -fp-model=precise -march=core-avx2
 */
double MRG32k3a (void)
{
    const double norm = 2.328306549295728e-10;
    const double m1 = 4294967087.0;
    const double m2 = 4294944443.0;
    const double a12 = 1403580.0;
    const double a13n = 810728.0;
    const double a21 = 527612.0;
    const double a23n = 1370589.0;
    double k, p1, p2;

    /* Component 1 */
    p1 = a12 * MRG32k3a_s11 - a13n * MRG32k3a_s10;
    k = floor (p1 / m1);  
    p1 -= k * m1;
    MRG32k3a_s10 = MRG32k3a_s11; MRG32k3a_s11 = MRG32k3a_s12; MRG32k3a_s12 = p1;
    /* Component 2 */
    p2 = a21 * MRG32k3a_s22 - a23n * MRG32k3a_s20;
    k = floor (p2 / m2);  
    p2 -= k * m2;
    MRG32k3a_s20 = MRG32k3a_s21; MRG32k3a_s21 = MRG32k3a_s22; MRG32k3a_s22 = p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}

static uint32_t MRG32k3a_s10i, MRG32k3a_s11i, MRG32k3a_s12i;
static uint32_t MRG32k3a_s20i, MRG32k3a_s21i, MRG32k3a_s22i;

#if BUILTIN_64BIT
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    uint64_t prod;
    uint32_t p1, p2;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    p1 = (uint32_t)(prod % m1);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)(prod % m2);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#elif GENERIC_MOD
uint64_t umul64hi (uint64_t a, uint64_t b)
{
    uint32_t alo = (uint32_t)a;
    uint32_t ahi = (uint32_t)(a >> 32);
    uint32_t blo = (uint32_t)b;
    uint32_t bhi = (uint32_t)(b >> 32);
    uint64_t p0 = (uint64_t)alo * blo;
    uint64_t p1 = (uint64_t)alo * bhi;
    uint64_t p2 = (uint64_t)ahi * blo;
    uint64_t p3 = (uint64_t)ahi * bhi;
    return (p1 >> 32) + (((p0 >> 32) + (uint64_t)(uint32_t)p1 + p2) >> 32) + p3;
}

double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209
    const uint32_t neg_m2 = 0 - m2; // 22853
    const uint64_t magic_mul_m1 = 0x8000006880005551ull;
    const uint64_t magic_mul_m2 = 0x4000165147c845ddull;
    const uint32_t shft_m1 = 31;
    const uint32_t shft_m2 = 30;
    uint64_t prod;
    uint32_t p1, p2;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    p1 = (uint32_t)((umul64hi (prod, magic_mul_m1) >> shft_m1) * neg_m1 + prod);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)((umul64hi (prod, magic_mul_m2) >> shft_m2) * neg_m2 + prod);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#else // !BUILTIN_64BIT && !GENERIC_MOD --> special fast modulo computation
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209
    const uint32_t neg_m2 = 0 - m2; // 22853
    uint64_t prod;
    uint32_t p1, p2, prod_lo, prod_hi, adj;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**56 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p1 = prod_hi * neg_m1 + prod_lo;
    if ((p1 >= m1) || (p1 < prod_lo)) p1 += neg_m1;
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**49 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p2 = prod_hi * neg_m2 + prod_lo;
    if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#endif // BUILTIN_64BIT

/*
  http://www.burtleburtle.net/bob/hash/doobs.html
  By Bob Jenkins, 1996.  bob_jenkins@burtleburtle.net.  You may use this
  code any way you wish, private, educational, or commercial.  It's free.
*/
#define mix(a,b,c) \
    (a -= b, a -= c, a ^= (c>>13), \
     b -= c, b -= a, b ^= (a<<8),  \
     c -= a, c -= b, c ^= (b>>13), \
     a -= b, a -= c, a ^= (c>>12), \
     b -= c, b -= a, b ^= (a<<16), \
     c -= a, c -= b, c ^= (b>>5),  \
     a -= b, a -= c, a ^= (c>>3),  \
     b -= c, b -= a, b ^= (a<<10), \
     c -= a, c -= b, c ^= (b>>15))

int main (void)
{
    uint32_t m1 = 4294967087u;
    uint32_t m2 = 4294944443u;
    uint32_t a, b, c;
    a = 3141592654u;
    b = 2718281828u;
    c = 10; MRG32k3a_s10 = MRG32k3a_s10i = (1u << 10) | (mix (a, b, c) % m1);
    c = 11; MRG32k3a_s11 = MRG32k3a_s11i = (1u << 11) | (mix (a, b, c) % m1);
    c = 12; MRG32k3a_s12 = MRG32k3a_s12i = (1u << 12) | (mix (a, b, c) % m1);
    c = 20; MRG32k3a_s20 = MRG32k3a_s20i = (1u << 20) | (mix (a, b, c) % m2);
    c = 21; MRG32k3a_s21 = MRG32k3a_s21i = (1u << 21) | (mix (a, b, c) % m2);
    c = 22; MRG32k3a_s22 = MRG32k3a_s22i = (1u << 22) | (mix (a, b, c) % m2);
    
    double res, ref;
    uint64_t count = 0;
    do {
        res = MRG32k3a_i();
        ref = MRG32k3a();
        if (res != ref) {
            printf("\ncount=%llu  ref=%23.16e  res=%23.16e\n", count, res, ref);
            return EXIT_FAILURE;
        }
        count++;
        if ((count & 0xfffffff) == 0) printf ("\rcount = %llu ", count);
    } while (ref != 0);
    return EXIT_SUCCESS;
}

Solution

  • I haven't tested or benchmarked this, but if I've understood what you're doing correctly, I think another option is to add a fixed multiple of m2 and have two rounds of reduction.

    prod = ((uint64_t)a21) * MRG32k3a_s22i + ((uint64_t)m2 << 22) - ((uint64_t)a23n) * MRG32k3a_s20i; // 55 bits
    

    Then you can omit the following two lines.

    adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
    

    and then do

    prod_lo = (uint32_t)prod; 
    prod_hi = (uint32_t)(prod >> 32); // 23 bits
    prod = (uint64_t)prod_hi * neg_m2 + prod_lo; // 39 bits
    prod_lo = (uint32_t)prod; 
    prod_hi = (uint32_t)(prod >> 32); 
    p2 = prod_hi * neg_m2 + prod_lo; 
    if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
    

    You also might want to change the Component 1 calculation to something like

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i + ((uint64_t)a13n) * (m1 - MRG32k3a_s10i); // 54 bits
    

    instead of the 2 steps

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure its positive
    

    so as not to rely on the compiler being clever enough to realise it only needs two multiplications. I think this is valid as 0 <= MRG32k3a_s1xi < m1, right?
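
    (A quick sanity check of the "// 54 bits" comment, under that assumption: with 0 <= MRG32k3a_s11i < m1 and 0 <= MRG32k3a_s10i < m1, the dividend is less than (a12 + a13n) ⋅ m1 = 2214308 ⋅ 4294967087 ≈ 9.51 ⋅ 10^15 < 2^54, so it also stays well below the 2^56 limit of the special reduction for m1.)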