I'm trying to improve the performance of the following modular exponentiation function I wrote. I feel like there might be some way of leveraging the fact that it reduces a 128-bit unsigned integer modulo the same 64-bit modulus n many times, up to 128 times! I don't mind sacrificing some portability and using 128-bit integer types, and I'm targeting x86_64. Is there anything faster in the general case than exponentiation by squaring?
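#include <stdint.h>

uint64_t modpowu64(uint64_t a, uint64_t e, uint64_t n) {
    // Returns a^e mod n (returns 0 for n == 0, where the result is undefined)
    if (n == 0) return 0;
    if (n == 1) return 0;             // every integer is congruent to 0 mod 1
    if (a < 2) return a;              // 0^e and 1^e (treats 0^0 as 0)
    unsigned __int128 res = 1;
    unsigned __int128 sq = a % n;     // 128-bit so sq*sq cannot overflow
    while (e) {
        if (e & 1ULL) res = (res * sq) % n;  // multiply in the current square
        sq = (sq * sq) % n;                  // square for the next bit of e
        e >>= 1;
    }
    return (uint64_t)res;
}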
The function above works, so maybe this isn't the right forum. Should this be on Code Review instead?
After going to town on this a bit, I've found that for an odd modulus n, using Montgomery form is much more optimisable than I initially expected and can give significant performance improvements: 2-3x in my tests.
The key realisation was that each modular multiplication (u64 * u64) in the exponentiation loop can be reduced, using only multiplications, additions, a shift and one comparison, to a 64-bit value that is congruent mod n to the product times 2^-64 (i.e. the product in Montgomery form). This avoids almost all of the mod n operations, because the intermediate results never need to be reduced all the way into the interval [0, n-1].
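Written out, this is the standard Montgomery reduction (REDC) identity. In the notation below (mine, not from the code), R = 2^64, n' plays the role of -ninv, T is prod128, and a tilde marks a value in Montgomery form:

```latex
\begin{aligned}
&R = 2^{64},\quad n \text{ odd},\quad n' \equiv -n^{-1} \pmod{R}\\
&\tilde a \equiv aR,\quad \tilde b \equiv bR \pmod{n},\quad T = \tilde a\,\tilde b\\
&m = (T \bmod R)\,n' \bmod R \;\Longrightarrow\; T + mn \equiv T - T \equiv 0 \pmod{R}\\
&t = \frac{T + mn}{R} \equiv T R^{-1} \equiv abR \pmod{n}
\end{aligned}
```

So t is again in Montgomery form, which is why the whole loop can stay in that representation. Dividing by R is just taking the high 64 bits, and since T + mn can exceed 2^128, the code has to detect the wraparound and add back 2^128 / R = 2^64, which is congruent to twoto64modn.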
The magic part is as follows, where ninv is the multiplicative inverse of n mod 2^64 and twoto64modn is 2^64 mod n; both only need to be calculated once per change of modulus.
unsigned __int128 prod128 = (unsigned __int128)ar*br; // 0 <= prod128 <= 2^128 - 2^65 + 1
uint64_t m = prod128 & 0xffffffffffffffffULL;
m *= -ninv; // 0 <= m < 2^64
unsigned __int128 mn128 = (unsigned __int128)m*n; // 0 <= mn128 <= n*(2^64 - 1) && mn128 = -arbr mod 2^64
unsigned __int128 sum = prod128 + mn128; // 0 <= sum <= 2^128 + (n - 3)2^64
unsigned __int128 max128 = (prod128 > mn128 ? prod128 : mn128);
uint64_t sumu64 = sum >> 64; // 0 <= sum >> 64 <= 2^64 + (n - 3)
// twoto64modn = 2^64 - xn for some x >= 1, so if sum overflowed,
// 0 <= sumu64 + twoto64modn <= (n - 3) + 2^64 - xn = 2^64 - (x-1)n - 3 < 2^64
return sumu64 + (sum < max128)*twoto64modn;
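To show how the pieces fit together, here is a minimal sketch (not the code from the repo) of a full exponentiation built around the reduction above. The names montmul64 and modpowu64_mont are hypothetical, it assumes GCC/Clang's unsigned __int128, and it falls back to plain square-and-multiply for even n, where Montgomery form doesn't apply. ninv is computed with Newton's iteration x = x*(2 - n*x), which doubles the number of correct low bits per step, and 2^64 mod n is computed as (0 - n) % n using unsigned wraparound.

```c
#include <stdint.h>

// Hypothetical wrapper around the reduction above (n must be odd).
// Returns a 64-bit value congruent to ar*br*2^-64 (mod n), not necessarily < n.
static inline uint64_t montmul64(uint64_t ar, uint64_t br, uint64_t n,
                                 uint64_t ninv, uint64_t twoto64modn) {
    unsigned __int128 prod128 = (unsigned __int128)ar * br;
    uint64_t m = (uint64_t)prod128 * -ninv;            // makes prod128 + m*n == 0 mod 2^64
    unsigned __int128 mn128 = (unsigned __int128)m * n;
    unsigned __int128 sum = prod128 + mn128;           // may wrap past 2^128
    unsigned __int128 max128 = prod128 > mn128 ? prod128 : mn128;
    uint64_t sumu64 = (uint64_t)(sum >> 64);           // exact division by 2^64
    // On wraparound, the lost 2^128 becomes 2^64 after the shift, i.e. twoto64modn mod n.
    return sumu64 + (sum < max128) * twoto64modn;
}

// Hypothetical name: a^e mod n, using Montgomery form when n is odd.
uint64_t modpowu64_mont(uint64_t a, uint64_t e, uint64_t n) {
    if (n == 0) return 0;
    if ((n & 1) == 0) {                  // even n: plain square-and-multiply
        unsigned __int128 res = 1, sq = a % n;
        for (; e; e >>= 1) {
            if (e & 1) res = res * sq % n;
            sq = sq * sq % n;
        }
        return (uint64_t)res;
    }
    uint64_t ninv = n;                   // odd n: n*n == 1 (mod 8), 3 correct bits
    for (int i = 0; i < 5; i++)          // 3 -> 6 -> 12 -> 24 -> 48 -> 96 bits
        ninv *= 2 - n * ninv;            // now n * ninv == 1 (mod 2^64)
    uint64_t r1 = (0 - n) % n;           // 2^64 mod n, the Montgomery form of 1
    // Convert the base into Montgomery form: (a mod n) * 2^64 mod n.
    uint64_t sq = (uint64_t)(((unsigned __int128)(a % n) << 64) % n);
    uint64_t res = r1;
    for (; e; e >>= 1) {
        if (e & 1) res = montmul64(res, sq, n, ninv, r1);
        sq = montmul64(sq, sq, n, ninv, r1);
    }
    // Multiplying by a plain 1 strips the 2^64 factor; the final % n lands in [0, n-1].
    return montmul64(res, 1, n, ninv, r1) % n;
}
```

Inside the loop, res and sq are only guaranteed to be 64-bit values congruent to the Montgomery form, not reduced into [0, n-1]; the bounds in the comments above show the reduction stays valid for any inputs up to 2^64 - 1, so the single % n at the end is enough.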
I've posted the full code in a GitHub repo here: https://github.com/FastAsChuff/Fast-Modular-Exponentiation/tree/main