java algorithm

Find the maximum of (minimum * sum) over all subsets


Given integer arrays A and B of size n

Consider all possible non-empty combinations (subsets) of indices, of size 1, 2, 3, ..., n. There are 2^n - 1 of them.

For each combination (subset), find the minimum of the selected elements of A, say min, and multiply it by the sum of the elements of B at the same indices.

Find the maximum of all these products. Return the result modulo (10^9 + 7).

Constraints:

1 <= n <= 10^5
1 <= A[i], B[i] <= 10^6

Example:

A = [1,1,3]
B = [1,2,2]

indices:  A      min         B    sum         min*sum
                 (subset A)       (subset B)
--------------------------------------------------------------
[0]       1       1          1       1        1*1=1
[1]       1       1          2       2        1*2=2
[2]       3       3          2       2        3*2=6
[0,1]     1,1     1          1,2     3        1*3=3
[0,2]     1,3     1          1,2     3        1*3=3
[1,2]     1,3     1          2,2     4        1*4=4
[0,1,2]   1,1,3   1          1,2,2   5        1*5=5

Result = max( all min*sum) = max(1,2,6,3,3,4,5) = 6

Here is my code:

public static int solve(List<Integer> A, List<Integer> B) {
    int n = A.size();
    long max = 0;
    final int MOD = 1_000_000_007;
    // Loop through all subsets using bitmasking
    for (int mask = 0; mask < (1 << n); mask++) {
        int min = Integer.MAX_VALUE;
        long sum = 0;

        for (int i = 0; i < n; i++) {
            if ((mask & (1 << i)) != 0) {
                min = Math.min(min, A.get(i));
                sum += B.get(i);
                sum %= MOD; // To prevent overflow
            }
        }
        long current = (int) ((long) min * sum % MOD);
        max = Math.max(max, current);
    }

    return (int) max;
}

This was asked during an interview on HackerRank. When I submitted this code, it passed only 4 out of 15 test cases; many failed with wrong output or timeout errors.

What is the correct approach to solve this?


Solution

  • You can try each element of A as the minimum value of some subset of A. For a fixed minimum, the only elements of B eligible for the sum are those whose corresponding element in A (at the same index) is not less than that minimum. There is one special case: the element of B at the index of the element we are trying as the minimum must always be taken.

    Then, we can pick the optimal elements by considering two cases (without attempting to try all subsets):

    1. If the minimum we are considering is positive, then we want to take all eligible positive elements in B.
    2. If the minimum is negative, then we want to take all eligible negative elements in B.

    Note that the case of the minimum being 0 can go into either case: the product will always be 0. Under the stated constraints (every A[i] and B[i] is at least 1), only the first case arises.

    We can precompute the sum of all positive elements in B whose corresponding element in A is not less than a particular value with a prefix sum array (after sorting a List of indexes of A in descending order based on the values in A). We can do the same for negative elements.

    Finally, for each element of A we can perform a binary search for the last position in the sorted order whose value is not less than that element, and use that position to access the prefix sum arrays.

    Time complexity: O(N log N)

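    // Requires Java 16+ for Stream.toList().
    import java.util.Comparator;
    import java.util.List;
    import java.util.stream.IntStream;

    public static int solve(List<Integer> A, List<Integer> B) {
        // Indexes of A, sorted by the values in A in descending order.
        var sortedIndexesByA = IntStream.range(0, A.size()).boxed()
                .sorted(Comparator.comparingInt(A::get).reversed()).toList();
        // Prefix sums over that order: positive parts and negative parts of B.
        long[] bSumPos = new long[B.size()], bSumNeg = new long[B.size()];
        for (int i = 0; i < B.size(); i++) {
            bSumPos[i] = (i > 0 ? bSumPos[i - 1] : 0) + 
                             Math.max(B.get(sortedIndexesByA.get(i)), 0);
            bSumNeg[i] = (i > 0 ? bSumNeg[i - 1] : 0) + 
                             Math.min(B.get(sortedIndexesByA.get(i)), 0);
        }
        long ans = Long.MIN_VALUE;
        for (int i = 0; i < A.size(); i++) {
            // Binary search: after the loop, high is the last position in the
            // sorted order whose A value is not less than A[i].
            int low = 0, high = A.size() - 1;
            while (low <= high) {
                int mid = low + high >>> 1;
                if (A.get(sortedIndexesByA.get(mid)) < A.get(i))
                    high = mid - 1;
                else
                    low = mid + 1;
            }
            // B[i] must be taken. If its sign kept it out of the chosen prefix
            // sum, add it explicitly.
            long sum = (A.get(i) > 0 ? bSumPos : bSumNeg)[high] + 
                       (A.get(i) > 0 ^ B.get(i) > 0 ? B.get(i) : 0);
            // Compare the raw products; the modulo is applied only once, at the end.
            ans = Math.max(ans, sum * A.get(i));
        }
        return (int) (ans % 1000000007);
    }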
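
If the inputs really do satisfy the stated constraints (every A[i] and B[i] is at least 1), the idea above can likely be simplified: no negative case and no binary search are needed. The following is only a sketch under that assumption, not a replacement for the general version. The class name `MaxMinTimesSum` is made up for illustration. It walks the indexes in descending order of A and keeps a running sum of B, so the minimum of A over the elements taken so far is always the A value of the element just added.

```java
import java.util.List;
import java.util.stream.IntStream;

// Hypothetical class name, used only for this sketch.
class MaxMinTimesSum {
    private static final long MOD = 1_000_000_007L;

    // Assumption: every A[i] and B[i] is >= 1, as in the stated constraints.
    public static int solve(List<Integer> A, List<Integer> B) {
        int n = A.size();
        // Indexes ordered by A descending: the minimum of A over any prefix
        // of this order is the A value of the prefix's last element.
        int[] order = IntStream.range(0, n).boxed()
                .sorted((x, y) -> Integer.compare(A.get(y), A.get(x)))
                .mapToInt(Integer::intValue)
                .toArray();
        long sum = 0;  // running sum of B over the current prefix
        long best = 0; // best raw (unreduced) product seen so far
        for (int idx : order) {
            sum += B.get(idx);
            // At most 10^6 * (10^5 * 10^6) = 10^17, which fits in a long.
            best = Math.max(best, (long) A.get(idx) * sum);
        }
        // Reduce only once, after all comparisons are done.
        return (int) (best % MOD);
    }

    public static void main(String[] args) {
        // Example from the question: A = [1,1,3], B = [1,2,2]
        System.out.println(solve(List.of(1, 1, 3), List.of(1, 2, 2))); // prints 6
    }
}
```

Since every B[i] is positive, adding another eligible element never lowers the sum, so only prefixes of the sorted order need to be considered. The products are compared as raw `long` values and reduced modulo 10^9 + 7 only at the very end, because comparing values that have already been reduced can pick the wrong maximum.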