
Need optimization tips for a subset-sum-like problem with a big constraint


Given a number 1 <= N <= 3*10^5, count all subsets of the set {1, 2, ..., N-1} that sum to N. This is essentially a variant of the subset sum problem, with the twist that the target sum is N itself and the set is just the integers 1 through N-1, increasing by 1.

I think I have solved this using DP (memoization in an ordered std::map) and a recursive include/exclude algorithm, but because of the time and space complexity I can't compute results for N beyond about 10000.

#include <iostream>
#include <chrono>
#include <map>
#include "bigint.h"


using namespace std;

//memoization table (std::map, an ordered tree rather than a hash map); key: (i, sum); value: count
map<pair<int, int>, bigint> hmap;

bigint counter(int n, int i, int sum){

    //end case
    if(i == 0){ 
        if(sum == 0){
            return 1;
        }
        return 0;
    }

    //alternative end case: the remaining sum reached zero before all elements were considered
    if(sum == 0){
        return 1;
    }

    //case if the result of the recursion is already in the map
    if(hmap.find(make_pair(i, sum)) != hmap.end()){
        return hmap[make_pair(i, sum)];
    }

    //only include i if the resulting sum wouldn't be negative
    if(sum - i < 0){
        //optimization: every element above sum is too large, so jump straight to i = sum
        return hmap[make_pair(i, sum)] = counter(n, sum, sum);
    }
    else{
                                        //include the number   don't include the number
        return hmap[make_pair(i, sum)] = counter(n, i - 1, sum - i) + counter(n, i - 1, sum);
    }
}

The function is called with the starting values N, N-1 and N, i.e. counter(N, N-1, N): the number of elements, the iterator i (which decrements), and the remaining sum of the recursive branch (which decreases with every included value).
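Written as a recurrence, the function computes the following, where C(i, s) stands for counter(n, i, sum) (the parameter n is never used inside the function, so it is dropped). This is only a restatement of the code above; it also makes visible that following just the exclude branch from C(N-1, N) already nests about N calls deep before reaching i = 0.

```latex
C(i, s) =
\begin{cases}
1 & \text{if } s = 0, \\
0 & \text{if } i = 0 \text{ and } s > 0, \\
C(s, s) & \text{if } 0 < s < i, \\
C(i-1,\, s-i) + C(i-1,\, s) & \text{otherwise,}
\end{cases}
\qquad \text{answer} = C(N-1,\, N).
```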

This is the code that calculates the number of subsets. For an input of 3000 it takes about 22 seconds to output the result, which is 40 digits long. Because the results have so many digits, I had to use rgroshanrg's arbitrary-precision bigint library, which works fine for values below ~10000. Testing beyond that gives me a segfault on lines 28-29 (the hmap.find check and the return hmap[...] lookup), maybe because the stored arbitrary-precision values become too big and conflict in the map.

I need to somehow improve this code so it works for values beyond 10000, but I am stumped. Any ideas, or should I switch to another algorithm and data structure?


Solution

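  • Here is a different algorithm, described in a paper by Evangelos Georgiadis, "Computing Partition Numbers q(n)". Here q(n) counts the partitions of n into distinct parts (OEIS A000009), and every subset of {1, 2, ..., N-1} that sums to N is one such partition of N: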

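    #include <cstddef>
    #include <vector>

    // BigInt: an arbitrary-precision integer type constructible from int and
    // supporting +=, -= and << (no specific library is implied here).
    // Returns q[0..n-1], where q[k] is the number of partitions of k into distinct parts.
    std::vector<BigInt> RestrictedPartitionNumbers(int n)
    {
        std::vector<BigInt> q(n, 0);
        // initialize q with A010815, the coefficients of prod_{k>=1} (1 - x^k):
        // (-1)^i at the generalized pentagonal numbers i(3i-1)/2 and i(3i+1)/2
        for (int i = 0; ; i++)
        {
            std::size_t n0 = i * (3 * i - 1) >> 1;
            if (n0 >= q.size())
                break;
            q[n0] = 1 - 2 * (i & 1);
            std::size_t n1 = i * (3 * i + 1) >> 1;
            if (n1 < q.size())
                q[n1] = 1 - 2 * (i & 1);
        }
        // construct A000009 as per "Evangelos Georgiadis, Computing Partition Numbers q(n)":
        // when the loop reaches k, q[k] is final, so push +2*q[k] (odd j)
        // or -2*q[k] (even j) forward to q[k + j*j]
        for (std::size_t k = 0; k < q.size(); k++)
        {
            std::size_t j = 1;
            std::size_t m = k + 1;
            while (m < q.size())
            {
                if ((j & 1) != 0)
                    q[m] += q[k] << 1;
                else
                    q[m] -= q[k] << 1;
                j++;
                m = k + j * j;
            }
        }
        return q;
    }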
    

    It's not the fastest algorithm out there; it took about half a minute on my computer for n = 300000. But you only need to run it once (it computes all partition numbers up to some bound), and it doesn't take much memory (a bit over 150MB).

    The results go up to but exclude n (so getting q(N) needs a table of size N + 1), and they count each number as a partition of itself, e.g. the set {4} is a partition of 4. Your definition of the problem excludes that case, so you need to subtract 1 from the result.

    Maybe there's a nicer way to express A010815; that part of the code isn't slow, though. I just think it looks bad.