
Computing (a*b) mod c quickly for c=2^N +-1


In 32-bit unsigned integer math, the basic operations of addition and multiplication are implicitly computed mod 2^32, meaning your result is the low-order 32 bits of the true sum or product.
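A tiny C sketch of that implicit wraparound, assuming the fixed-width unsigned type `uint32_t` from `<stdint.h>`. Unsigned arithmetic in C is defined to wrap, while signed overflow is not, so the example sticks to unsigned values. The names `add32` and `mul32` are made up for illustration:

```c
#include <stdint.h>

/* Hypothetical helpers: uint32_t arithmetic keeps only the low 32 bits,
   i.e. results are reduced mod 2^32 automatically. */
uint32_t add32(uint32_t a, uint32_t b) { return a + b; }
uint32_t mul32(uint32_t a, uint32_t b) { return a * b; }
```

For instance, `add32(0xFFFFFFFF, 2)` is 1 and `mul32(0x10000, 0x10001)` is 0x10000, because the 2^32 part of each true result is dropped.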

If you want to compute the result with a different modulus, you could certainly use one of the many BigInt classes available in different languages. And for values a, b, c < 2^32 you could compute the intermediate values in 64-bit integers and use the built-in % operator to reduce to the right answer.
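For reference, the 64-bit approach can be sketched in C as below. `mulmod_u64` is a hypothetical name, and the sketch assumes 0 < c < 2^32, so the product of two 32-bit values always fits in a `uint64_t`:

```c
#include <stdint.h>

/* Baseline: widen to 64 bits, multiply, then reduce with the built-in %.
   Works for any modulus c with 0 < c < 2^32. */
uint32_t mulmod_u64(uint32_t a, uint32_t b, uint32_t c) {
    return (uint32_t)(((uint64_t)a * b) % c);
}
```

This is the baseline the special-case tricks are trying to beat: it needs a 64-bit multiply followed by a general-purpose 64-bit division.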

But I've been told that there are special tricks for computing a*b mod C when C is of the form (2^N)-1 or (2^N)+1 that need neither 64-bit math nor a BigInt library, are faster than reducing by an arbitrary modulus, and still give correct results in cases where the intermediate product would overflow a 32-bit int.
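The identity that makes both forms cheap is 2^N ≡ 1 (mod 2^N - 1) and 2^N ≡ -1 (mod 2^N + 1), so splitting a value into high and low N-bit halves turns reduction into an add or a subtract. A minimal C sketch for N = 16, reducing any 32-bit value (the function names are hypothetical):

```c
#include <stdint.h>

/* mod 2^16 - 1: since 2^16 ≡ 1, hi*2^16 + lo ≡ hi + lo */
uint32_t reduce_mod_65535(uint32_t v) {
    uint32_t r = (v >> 16) + (v & 0xFFFFu);   /* < 2^17 */
    r = (r >> 16) + (r & 0xFFFFu);            /* <= 2^16 */
    return r >= 65535u ? r - 65535u : r;      /* final range [0, 65534] */
}

/* mod 2^16 + 1: since 2^16 ≡ -1, hi*2^16 + lo ≡ lo - hi */
uint32_t reduce_mod_65537(uint32_t v) {
    int32_t r = (int32_t)(v & 0xFFFFu) - (int32_t)(v >> 16);  /* in (-2^16, 2^16) */
    return (uint32_t)(r < 0 ? r + 65537 : r);                  /* final range [0, 65536] */
}
```

The same split applied to a product (rather than an arbitrary 32-bit value) is what the multiply-with-mod tricks build on.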

Unfortunately, despite hearing that such special cases have a fast evaluation method, I haven't actually found a description of the method. "Isn't that in Knuth?" "Isn't that somewhere on Wikipedia?" are the mumblings I've heard.

It's apparently a common technique in random number generators that compute a*b mod 2147483647, since 2147483647 = 2^31 - 1 is prime.
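As an illustration (my own sketch, not something from the question), here is one step of the Park-Miller "minimal standard" generator, x -> 16807 * x mod (2^31 - 1), using only 32-bit products. The name `minstd_next` is hypothetical; it splits x into a 15-bit high half and a 16-bit low half and folds with 2^31 ≡ 1:

```c
#include <stdint.h>

#define MINSTD_M 0x7FFFFFFFu  /* 2^31 - 1 */
#define MINSTD_A 16807u       /* 7^5, the Park-Miller multiplier */

/* One step x -> 16807 * x mod (2^31 - 1) for 1 <= x < 2^31 - 1.
   Write x = hi_x * 2^16 + lo_x, so 16807*x = lo + hi * 2^16. */
uint32_t minstd_next(uint32_t x) {
    uint32_t lo = MINSTD_A * (x & 0xFFFFu);  /* < 2^31 */
    uint32_t hi = MINSTD_A * (x >> 16);      /* < 2^30 */
    /* hi * 2^16 = (hi >> 15) * 2^31 + (hi & 0x7FFF) * 2^16, and 2^31 ≡ 1 */
    lo += (hi & 0x7FFFu) << 16;
    lo += hi >> 15;                          /* total stays below 2^32 */
    if (lo >= MINSTD_M)                      /* one subtraction is enough here */
        lo -= MINSTD_M;
    return lo;
}
```

As a sanity check, the C++ standard specifies that the 10000th output of `std::minstd_rand0` (this same generator, seeded with 1) is 1043618065, which this function should reproduce after 10000 steps from x = 1.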

So I'll ask the experts. What's this clever special case multiply-with-mod method that I can't find any discussion of?


Solution

  • I think the trick is the following (I'm going to do it in base 10 because it's easier, but the principle should hold):

    Suppose you are multiplying a*b mod 10000-1, and

    a = 1234 = 12 * 100 + 34
    b = 5432 = 54 * 100 + 32
    

    now a*b = 12 * 54 * 10000 + 34 * 54 * 100 + 12 * 32 * 100 + 34 * 32

    12 * 54 * 10000 =  648 * 10000
    34 * 54 * 100   = 1836 * 100
    12 * 32 * 100   =  384 * 100
    34 * 32         = 1088
    

    Since x * 10000 ≡ x (mod 10000-1) [1], the first and last terms become 648 + 1088. The second and third terms are where the 'trick' comes in. Note that:

    1836 = 18 * 100 + 36
    1836 * 100 ≡ 18 * 10000 + 3600 ≡ 3618 (mod 10000-1).
    

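    This is essentially a circular shift, and 384 * 100 ≡ 3 * 10000 + 8400 ≡ 8403 in the same way. That gives 648 + 3618 + 8403 + 1088 = 13757, and one more fold (13757 = 1 * 10000 + 3757 ≡ 1 + 3757) gives 3758, which is indeed 1234 * 5432 mod 9999. Also note that every multiplication here is between the two-digit halves of a and b (each < 100), so every product is < 10000; this is calculable if you can only multiply 2-digit numbers together and add them.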
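Putting the base-10 steps together, a sketch that computes a*b mod 9999 while only ever multiplying two-digit numbers might look like this in C. The function names are hypothetical, and it assumes 0 <= a, b < 9999:

```c
/* 10^4 ≡ 1 (mod 9999), so hi*10^4 + lo ≡ hi + lo; repeat until x < 10^4 */
static unsigned fold9999(unsigned x) {
    while (x >= 10000u)
        x = x / 10000u + x % 10000u;
    return x == 9999u ? 0u : x;
}

/* For x < 9999, x*100 mod 9999 swaps the two-digit halves of x */
static unsigned times100(unsigned x) {
    return (x % 100u) * 100u + x / 100u;
}

unsigned mulmod9999(unsigned a, unsigned b) {
    unsigned a1 = a / 100u, a0 = a % 100u;   /* a = a1*100 + a0 */
    unsigned b1 = b / 100u, b0 = b % 100u;   /* b = b1*100 + b0 */
    unsigned sum = a1 * b1                   /* times 10^4 ≡ 1 */
                 + times100(a0 * b1)         /* times 10^2: rotate */
                 + times100(a1 * b0)
                 + a0 * b0;                  /* times 1 */
    return fold9999(sum);
}
```

For `mulmod9999(1234, 5432)` the four terms are 648, 3618, 8403 and 1088, and the final fold turns their sum 13757 into 3758.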

    In binary, it's going to work out similarly.

    Start with a and b, both 32 bits. Suppose you want to multiply them mod 2^31 - 1, but you only have a 16-bit multiplier (16 bits * 16 bits giving a 32-bit product). The algorithm would be something like this:

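     #include <stdint.h>

     #define P 0x7FFFFFFFu  // 2^31 - 1

     // Reduce a 32-bit value mod 2^31 - 1, using 2^31 ≡ 1
     static uint32_t fold31(uint32_t v) {
         v = (v & P) + (v >> 31);            // at most 2^31
         return v >= P ? v - P : v;
     }

     // a*b mod 2^31 - 1 using only 16-bit * 16-bit = 32-bit multiplies
     uint32_t mulmod31(uint32_t a, uint32_t b) {  // e.g. a = 0x12345678, b = 0xfedbca98
         uint32_t accumulator = 0;
         a = fold31(a);                      // make sure a, b < 2^31 - 1
         b = fold31(b);
         for (unsigned x = 0; x < 32; x += 16)
             for (unsigned y = 0; y < 32; y += 16) {
                 // do the multiplication, 16-bit * 16-bit = 32-bit, then fold below P
                 uint32_t temp = fold31(((a >> x) & 0xFFFF) * ((b >> y) & 0xFFFF));

                 // temp * 2^(x+y) mod P is a 31-bit circular left shift of temp
                 unsigned total_bits_shifted = (x + y) % 31;
                 if (total_bits_shifted != 0)
                     temp = ((temp << total_bits_shifted) & P) | (temp >> (31 - total_bits_shifted));

                 // add to the accumulator; both terms are < P, so the sum fits in 32 bits
                 accumulator = fold31(accumulator + temp);
             }
         return accumulator;
     }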
    

    It's late, so double-check the accumulator folding before relying on it; I think the principle is right, though. Someone feel free to edit this if it needs fixing.

    Unrolled, this is also pretty fast, which I'm guessing is what the PRNGs use.
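For comparison, here is what a fully unrolled version for 2^31 - 1 might look like in C, assuming a, b < 2^31 - 1 and using only four 16x16-bit multiplies. The names `fold31` and `mulmod31_unrolled` are hypothetical, and this is a sketch written for clarity, not a tuned or benchmarked implementation:

```c
#include <stdint.h>

#define M31 0x7FFFFFFFu  /* 2^31 - 1 */

/* Hypothetical helper: reduce a 32-bit value mod 2^31 - 1 (2^31 ≡ 1). */
static uint32_t fold31(uint32_t v) {
    v = (v & M31) + (v >> 31);
    return v >= M31 ? v - M31 : v;
}

/* a*b mod (2^31 - 1) for a, b < 2^31 - 1.
   Write a = a1*2^16 + a0 and b = b1*2^16 + b0, with a1, b1 < 2^15. */
uint32_t mulmod31_unrolled(uint32_t a, uint32_t b) {
    uint32_t a1 = a >> 16, a0 = a & 0xFFFFu;
    uint32_t b1 = b >> 16, b0 = b & 0xFFFFu;

    uint32_t hh = a1 * b1;   /* < 2^30; weight 2^32 ≡ 2 */
    uint32_t m1 = a1 * b0;   /* < 2^31; weight 2^16 */
    uint32_t m2 = a0 * b1;   /* < 2^31; weight 2^16 */
    uint32_t ll = a0 * b0;   /* < 2^32; weight 1 */

    /* m * 2^16 = (m >> 15) * 2^31 + (m & 0x7FFF) * 2^16
                ≡ (m >> 15) + ((m & 0x7FFF) << 16); the bit ranges don't overlap */
    uint32_t r1 = ((m1 & 0x7FFFu) << 16) | (m1 >> 15);
    uint32_t r2 = ((m2 & 0x7FFFu) << 16) | (m2 >> 15);

    /* every partial sum is two values < 2^31 - 1, so it fits in 32 bits */
    uint32_t acc = fold31(hh << 1);
    acc = fold31(acc + fold31(r1));
    acc = fold31(acc + fold31(r2));
    acc = fold31(acc + fold31(ll));
    return acc;
}
```

The loop version visits the same four partial products; unrolling just replaces the loop and the variable rotation with fixed shifts for the 2^32 and 2^16 weights.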

    [1]: x*10000 ≡ x*(9999+1) ≡ 9999*x + x ≡ x (mod 9999)