java combinatorics

The multiplicative formula for calculating the Binomial Coefficient in Java for large n and k values modulo 10^9 + 7 outputs incorrect value


I have an assignment that asks for a program that calculates the binomial coefficient for any n, k such that 1<=k<=n<=2000. I can do this for small n and k, but for some reason the result is not accurate for larger n and k values.

Here is my current code

import java.io.BufferedInputStream;
import java.util.Scanner;

public class Main
{
    public static double bCoeffModded(double nPar, double kPar, double mPar) {
        if ((nPar == kPar) || (kPar == 0)) {
            return 1;
        }
        
        double binomialCoefficient = nPar;
        
        for (int i = 2; i <= kPar; i++) {
            binomialCoefficient *= ((nPar + 1 - i) / i);
            binomialCoefficient %= mPar;
        }
        return binomialCoefficient;
    }
    
    public static void main(String[] args) {
        // Declarations
        double n, k;
        double m = (int) Math.pow(10,9) + 7; // modulus
        int c;
        
        // Input
        Scanner reader = new Scanner(new BufferedInputStream(System.in));
        n = reader.nextInt();
        k = reader.nextInt();
        
        // Calculations
        c = (int) bCoeffModded(n, k, m);
        
        // OutPut
        System.out.println(c);
    }
}

I apply the modulus in each iteration because it reduces the chance of an overflow. I suspect something still overflows, because the accuracy drops for larger inputs.

The formula and an overview of the binomial coefficient can be found here. I should note that I already tried the recursive method and the factorial method. Both have overflow issues as well.

test case 1:
6 4
expected 15
actual 15

test case 2:
100 50
expected 538992043
actual 309695578


Solution

  • Well, as I mentioned in the comments, you can't combine ordinary (integer or floating-point) division with modular reduction: once an intermediate value has been reduced modulo m, dividing it no longer gives the residue of the true quotient, so you get wrong results. That means you need a recurrence relation or an algorithm that does not involve division. I see the following options:

    I. The obvious recurrence relation is based on Pascal's triangle:

    (n,k) = (n-1,k-1) + (n-1,k)

    It does not involve division, hence we can perform our computations using the following idea:

    mod((n,k),m) = mod(mod((n-1,k-1),m) + mod((n-1,k),m), m)

    The problem is that the best time complexity we can achieve this way is about O(n*k); moreover, we need enough memory to store either the intermediate results or the call stack.

    public static int ncr(int n, int r) {
        // todo: could be optimized to O(r) space
        int[][] interim = new int[n + 1][r + 1];
        return ncr(n, r, interim);
    }
    
    public static int ncr(int n, int r, int[][] interim) {
        int modulo = 1000000007;
        if (r == 0 || n == r) {
            return 1;
        }
        if (interim[n][r] > 0) {
            return interim[n][r];
        }
        return interim[n][r] = (ncr(n - 1, r - 1, interim) + ncr(n - 1, r, interim)) % modulo;
    }
    

    However, I believe this approach (with some obvious optimizations) is the only one that is natural from a programmer's perspective; the others require some research or a math background.
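
    One of those optimizations is the O(r) space version hinted at in the todo comment. The following is a minimal sketch, not part of the original answer: it builds Pascal's triangle bottom-up and keeps a single row. The class name PascalRow is hypothetical, and the early return of 0 for r outside 0..n is an added assumption.

```java
// Hypothetical helper class; the name PascalRow is not from the original answer.
final class PascalRow {
    static final int MODULO = 1_000_000_007;

    // Bottom-up Pascal's triangle keeping one row: O(n*r) time, O(r) space.
    static int ncr(int n, int r) {
        if (r < 0 || r > n) {
            return 0; // assumption: out-of-range r yields 0
        }
        int[] row = new int[r + 1];
        row[0] = 1; // C(0, 0) = 1
        for (int i = 1; i <= n; i++) {
            // Walk right to left so row[j - 1] still holds the value from row i - 1.
            for (int j = Math.min(i, r); j >= 1; j--) {
                // Both terms are below MODULO, so the sum still fits in an int.
                row[j] = (row[j] + row[j - 1]) % MODULO;
            }
        }
        return row[r];
    }
}
```

    This version also avoids the recursion, so the call stack no longer grows with n.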

    II. math solution

    By definition:

    (n,k) = n*(n-1)*..*(n-k+1)/(1*2*..*k)

    Here we do know how to calculate mod(n*(n-1)*..*(n-k+1),m), and that is not a big deal:

    mod(mod(mod(n,m)*mod(n-1,m),m)*mod(n-2,m),m)....

    The actual challenge is what to do with the 1/(1*2*..*k) part. The main idea: is it possible to replace the division in mod(p/q, m) by a multiplication, i.e. to find some r that satisfies mod(p/q, m) = mod(p*r, m)? The answer is yes. Since m = 10^9 + 7 is prime and q is not divisible by m, Fermat's little theorem gives r = pow(q,m-2), and we can calculate mod(pow(q,m-2),m) in O(log m) time:

    public static int modpow(int x, int n) {
        int modulo = 1000000007;
        long product = 1;
        long p = x;
        while (n != 0) {
            if ((n & 1) == 1) {
                product = product * p % modulo;
            }
            p = (p * p % modulo);
            n >>= 1;
        }
        return (int) product;
    }
    
    public static int ncr(int n, int r) {
        int modulo = 1000000007;
    
        long mult = 1;
        for (int i = n; i > n - r; i--) {
            mult = (mult * i) % modulo;
        }
        
        for (int i = 2; i <= r; i++) {
            mult = (mult * modpow(i, modulo - 2)) % modulo;
        }
    
        return (int) mult;
    }
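
    To see concretely that pow(q, m-2) behaves like 1/q, here is a small self-contained check. It is only a sketch: the class name FermatInverseDemo and the helper inverse are hypothetical, and it assumes the modulus is prime and q is not a multiple of it.

```java
// Hypothetical demo class; these names are not from the original answer.
final class FermatInverseDemo {
    static final int MODULO = 1_000_000_007; // prime, so Fermat's little theorem applies

    // Square-and-multiply: x^n mod MODULO in O(log n) multiplications (x >= 0 assumed).
    static long modpow(long x, long n) {
        long result = 1;
        long base = x % MODULO;
        while (n > 0) {
            if ((n & 1) == 1) {
                result = result * base % MODULO;
            }
            base = base * base % MODULO;
            n >>= 1;
        }
        return result;
    }

    // Inverse of q modulo MODULO; assumes q is not a multiple of MODULO.
    static long inverse(long q) {
        return modpow(q, MODULO - 2);
    }

    public static void main(String[] args) {
        long inv3 = inverse(3);
        System.out.println(inv3);              // 333333336
        System.out.println(3 * inv3 % MODULO); // 1, so inv3 acts as 1/3
        System.out.println(6 * inv3 % MODULO); // 2, i.e. 6 / 3
    }
}
```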
    

    However, as @Dawood ibn Kareem pointed out in the comments, BigInteger#modInverse returns the same modular inverse, so we can rewrite the previous code in the following form (which is slower than the previous one):

    public static int ncr(int n, int r) {
        int modulo = 1000000007;
        BigInteger bm = BigInteger.valueOf(modulo);
    
        long mult = 1;
        for (int i = n; i > n - r; i--) {
            mult = (mult * i) % modulo;
        }
    
        for (int i = 2; i <= r; i++) {
            mult = (mult * BigInteger.valueOf(i).modInverse(bm).intValue()) % modulo;
        }
    
        return (int) mult;
    }
    

    On the other hand, that is too Java-specific, and most probably we need to "find" another way of calculating the modular inverse; the extended Euclidean algorithm comes to the rescue:

    function inverse(a, n)
        t := 0;     newt := 1
        r := n;     newr := a
    
        while newr ≠ 0 do
            quotient := r div newr
            (t, newt) := (newt, t − quotient × newt) 
            (r, newr) := (newr, r − quotient × newr)
    
        if r > 1 then
            return "a is not invertible"
        if t < 0 then
            t := t + n
    
        return t
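
    For reference, the pseudocode above translates to Java almost line by line. This is a sketch under assumptions: the class name ExtendedEuclid is hypothetical, and throwing ArithmeticException for the "not invertible" case is one possible choice rather than something the original answer prescribes.

```java
// Hypothetical helper class; a direct transcription of the pseudocode above.
final class ExtendedEuclid {
    // Returns the inverse of a modulo n, or throws if gcd(a, n) != 1.
    static long inverse(long a, long n) {
        long t = 0, newT = 1;
        long r = n, newR = a;
        while (newR != 0) {
            long quotient = r / newR;
            // Simultaneous assignment (t, newT) := (newT, t - quotient * newT)
            long nextT = t - quotient * newT;
            t = newT;
            newT = nextT;
            // Simultaneous assignment (r, newR) := (newR, r - quotient * newR)
            long nextR = r - quotient * newR;
            r = newR;
            newR = nextR;
        }
        if (r > 1) {
            throw new ArithmeticException("a is not invertible");
        }
        if (t < 0) {
            t += n; // normalize into the range 0..n-1
        }
        return t;
    }
}
```

    For example, ExtendedEuclid.inverse(3, 1000000007) returns 333333336, since 3 * 333333336 = 1000000008.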
    

    which in our case can be simplified to the following recurrence relation (note that since m % x < x, we can use a dynamic programming approach):

    inv(x) = (inv(m % x) * (m - m/x)) % m
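
    The recurrence follows from dividing m by x with remainder. A short derivation, assuming m is prime and 1 < x < m, so that m mod x is nonzero and therefore invertible:

```latex
\begin{align*}
m &= \left\lfloor \tfrac{m}{x} \right\rfloor x + (m \bmod x) \\
0 &\equiv \left\lfloor \tfrac{m}{x} \right\rfloor x + (m \bmod x) \pmod{m} \\
0 &\equiv \left\lfloor \tfrac{m}{x} \right\rfloor (m \bmod x)^{-1} + x^{-1} \pmod{m}
  && \text{(multiply by } x^{-1} (m \bmod x)^{-1}\text{)} \\
x^{-1} &\equiv -\left\lfloor \tfrac{m}{x} \right\rfloor (m \bmod x)^{-1}
  \equiv \left(m - \left\lfloor \tfrac{m}{x} \right\rfloor\right) (m \bmod x)^{-1} \pmod{m}
\end{align*}
```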
    

    and now we get:

    public static int ncr(int n, int r) {
        int modulo = 1000000007;
    
        long mult = 1;
        for (int i = n; i > n - r; i--) {
            mult = (mult * i) % modulo;
        }
    
        int[] invs = new int[r + 1];
        invs[0] = 1;
        if (r > 1) {
            invs[1] = 1;
        }
    
        for (int i = 2; i <= r; i++) {
            invs[i] = (int) (((long) invs[modulo % i] * (modulo - modulo / i)) % modulo);
            mult = (mult * invs[i]) % modulo;
        }
    
        return (int) mult;
    }
    

    UPD. It seems I missed the obvious optimization: the fastest of the suggested solutions is actually the following, which computes only one modular inverse (no knowledge of the extended Euclidean algorithm required):

    public static int modpow(int x, int n) {
        int modulo = 1000000007;
        long product = 1;
        long p = x;
        while (n != 0) {
            if ((n & 1) == 1) {
                product = product * p % modulo;
            }
            p = (p * p % modulo);
            n >>= 1;
        }
        return (int) product;
    }
    
    public static int ncr(int n, int r) {
        int modulo = 1000000007;
    
        long mult = 1;
        for (int i = n; i > n - r; i--) {
            mult = (mult * i) % modulo;
        }
    
        long div = 1;
        for (int i = 2; i <= r; i++) {
            div = (div * i) % modulo;
        }
    
        //int modInverse = BigInteger.valueOf(div).modInverse(BigInteger.valueOf(modulo)).intValue();
        //return (int) ((mult * modInverse) % modulo);
        return (int) ((mult * modpow((int) div, modulo - 2)) % modulo);
    }
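
    To tie this back to the question, here is a self-contained sketch that runs the last approach on the two test cases quoted there. The class name BinomialMod is hypothetical, the two loops are merged into one for brevity, and the code assumes 0 <= r <= n < modulo (true for n <= 2000).

```java
// Hypothetical wrapper class; the approach is the "single inverse" variant above.
final class BinomialMod {
    static final int MODULO = 1_000_000_007;

    // Square-and-multiply: x^n mod MODULO (x >= 0 assumed).
    static long modpow(long x, long n) {
        long result = 1;
        long base = x % MODULO;
        while (n > 0) {
            if ((n & 1) == 1) {
                result = result * base % MODULO;
            }
            base = base * base % MODULO;
            n >>= 1;
        }
        return result;
    }

    // C(n, r) mod MODULO using long arithmetic only, no double and no division.
    static int ncr(int n, int r) {
        if (r < 0 || r > n) {
            return 0; // assumption: out-of-range r yields 0
        }
        long numerator = 1;   // n * (n-1) * ... * (n-r+1) mod MODULO
        long denominator = 1; // 1 * 2 * ... * r mod MODULO
        for (int i = 0; i < r; i++) {
            numerator = numerator * (n - i) % MODULO;
            denominator = denominator * (i + 1) % MODULO;
        }
        // One modular inverse of the whole denominator (Fermat's little theorem).
        return (int) (numerator * modpow(denominator, MODULO - 2) % MODULO);
    }

    public static void main(String[] args) {
        System.out.println(ncr(6, 4));    // 15
        System.out.println(ncr(100, 50)); // 538992043, the expected value from the question
    }
}
```

    Every intermediate product stays below MODULO squared, which fits in a long, so nothing overflows and no precision is lost as it is with double.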