rust modulo integer-overflow int128

How to work around a possible multiplication overflow to get a correct modulus result?


I have to compute (a * b) % m, but a, b, and m are 128-bit unsigned integers, and the multiplication is very likely to overflow. How can I still get a correct answer (probably by applying % more than once)?

I'm trying to implement modular exponentiation in Rust, where the largest built-in integer type is u128 (the widest type I can use). All three variables can be very large, so a * b can easily exceed 2^128 - 1. I can use a.overflowing_mul(b) to detect whether an overflow occurred, but I do not know how to get from the overflowed result (which can be thought of as (a * b) % 2^128) to (a * b) % m.

My modular exponentiation code looks like this (overflow is not handled yet):

fn mod_exp(b: u128, e: u128, m: u128) -> u128 {
    (0..e).fold(1, |x, _| (x * b) % m)
    //                    ^^^^^^^^^^^
}

From a mathematical perspective:

(a * b) % m is actually computed as ((a * b) % B) % m
where B is the wrap-around base of the type (2^128 for u128)

Examples:

// Mathematical
(9 * 13) % 11 = 7
// Real (base 20):
(9 * 13) % (B = 20) % 11 = 6
         ^^^^^^^^^^        ^ should be 7

(8 * 4) % 14 = 4
(8 * 4) % (B = 16) % 14 = 0
        ^^^^^^^^^^        ^ should be 4
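To make the wrap-around concrete, here is a minimal sketch that uses u8 in place of u128, so the exact product can still be checked in a wider type (which is precisely what is unavailable for u128). The function names wrapped_mul_mod and exact_mul_mod are hypothetical and only serve this illustration.

```rust
/// What a fixed-width type effectively computes: the product wraps
/// modulo 2^8 first, and only then is reduced modulo m.
/// Also reports whether the multiplication overflowed.
fn wrapped_mul_mod(a: u8, b: u8, m: u8) -> (u8, bool) {
    let (wrapped, overflowed) = a.overflowing_mul(b);
    (wrapped % m, overflowed)
}

/// The mathematically correct result, obtained by widening to u16
/// so that the full product is available before reducing modulo m.
fn exact_mul_mod(a: u8, b: u8, m: u8) -> u8 {
    ((a as u16 * b as u16) % m as u16) as u8
}
```

With a = 200, b = 3, m = 7 the product 600 wraps to 88, so the wrapped version yields 4 while the exact result is 5. When the product fits (for example 9 * 13 = 117), both agree.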

Solution

  • This implementation, based on splitting each 128-bit operand into two 64-bit halves and reducing the four resulting partial products modulo m, is five times as fast as num_bigint::BigUint, ten times as fast as uint::U256, and 2.3 times as fast as gmp::mpz::Mpz:

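    fn mul_mod(a: u128, b: u128, m: u128) -> u128 {
        if m <= 1 << 64 {
            // Both residues are below 2^64, so their product fits in a u128.
            ((a % m) * (b % m)) % m
        } else {
            // (x + y) % m without overflow, assuming x < m and y < m.
            let add = |x: u128, y: u128| x.checked_sub(m - y).unwrap_or_else(|| x + y);
            // Split a value into its (high, low) 64-bit halves.
            let split = |x: u128| (x >> 64, x & !(!0 << 64));
            let (a_hi, a_lo) = split(a);
            let (b_hi, b_lo) = split(b);
            // Accumulate the partial products from most to least significant.
            let mut c = a_hi * b_hi % m;
            let (d_hi, d_lo) = split(a_lo * b_hi);
            c = add(c, d_hi);
            let (e_hi, e_lo) = split(a_hi * b_lo);
            c = add(c, e_hi);
            // Multiply c by 2^64 modulo m using 64 modular doublings.
            for _ in 0..64 {
                c = add(c, c);
            }
            c = add(c, d_lo);
            c = add(c, e_lo);
            let (f_hi, f_lo) = split(a_lo * b_lo);
            c = add(c, f_hi);
            // Shift by another 64 bits before adding the lowest limb.
            for _ in 0..64 {
                c = add(c, c);
            }
            add(c, f_lo)
        }
    }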
    

    (Warning: none of these implementations are suitable for use in cryptographic code, since they are not hardened against side channel attacks.)
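As a rough sketch of how this could be plugged back into the question's mod_exp, the version below replaces the overflowing (x * b) % m step with mul_mod. Note two assumptions that are not part of the answer: it uses square-and-multiply instead of the question's linear fold (a swapped-in technique, so the loop runs about 128 times rather than e times), and it assumes m > 0. The mul_mod body is reproduced from above so that the block compiles on its own.

```rust
/// (a * b) % m without overflow, reproduced from the answer above.
fn mul_mod(a: u128, b: u128, m: u128) -> u128 {
    if m <= 1 << 64 {
        ((a % m) * (b % m)) % m
    } else {
        let add = |x: u128, y: u128| x.checked_sub(m - y).unwrap_or_else(|| x + y);
        let split = |x: u128| (x >> 64, x & !(!0 << 64));
        let (a_hi, a_lo) = split(a);
        let (b_hi, b_lo) = split(b);
        let mut c = a_hi * b_hi % m;
        let (d_hi, d_lo) = split(a_lo * b_hi);
        c = add(c, d_hi);
        let (e_hi, e_lo) = split(a_hi * b_lo);
        c = add(c, e_hi);
        for _ in 0..64 {
            c = add(c, c);
        }
        c = add(c, d_lo);
        c = add(c, e_lo);
        let (f_hi, f_lo) = split(a_lo * b_lo);
        c = add(c, f_hi);
        for _ in 0..64 {
            c = add(c, c);
        }
        add(c, f_lo)
    }
}

/// Hypothetical mod_exp built on mul_mod, using square-and-multiply.
/// Assumes m > 0 (m == 0 would panic on the `%` below).
fn mod_exp(mut base: u128, mut exp: u128, m: u128) -> u128 {
    if m == 1 {
        return 0; // everything is congruent to 0 modulo 1
    }
    let mut result = 1;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            // This bit of the exponent is set: fold the current power in.
            result = mul_mod(result, base, m);
        }
        base = mul_mod(base, base, m); // square for the next bit
        exp >>= 1;
    }
    result
}
```

For example, mod_exp(4, 13, 497) gives 445, and mod_exp(2, 128, u128::MAX) gives 1, since 2^128 is congruent to 1 modulo 2^128 - 1. The same warning about side channels applies to this sketch.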