
How to efficiently count all distinct triplets of the form (a, b, b) in an array? Length of array <= 10^6


As stated above, I need to efficiently count the number of distinct triplets of the form (a, b, b), where a != b. A triplet is valid if and only if it can be formed by deleting some integers from the array, leaving behind exactly that triplet in that order. In other words, the three elements must appear in the array in that order, but they don't have to be consecutive. The solution needs to be very efficient, as N (the length of the array) can be up to 10^6. For example, if the array is [5, 6, 7, 3, 3, 3], the answer is 3, because the valid triplets are (5, 3, 3), (6, 3, 3), and (7, 3, 3).
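The validity rule above says the triplet must be a subsequence of the array. As a minimal sketch of that definition for a single candidate (the helper name `is_valid_triplet` is hypothetical, not from the question), Python's `x in iterator` consumes the iterator up to the first match, so successive checks can only find elements further to the right:

```python
def is_valid_triplet(arr, triplet):
    """Hypothetical helper: True if `triplet` has the form (a, b, b) with
    a != b and occurs in `arr` as a subsequence (order kept, gaps allowed)."""
    a, b, c = triplet
    if a == b or b != c:
        return False  # not of the form (a, b, b)
    it = iter(arr)
    # Each `x in it` advances the shared iterator past the first match,
    # so a, then b, then b must be found strictly left to right.
    return all(x in it for x in triplet)


arr = [5, 6, 7, 3, 3, 3]
print(is_valid_triplet(arr, (5, 3, 3)))  # True
print(is_valid_triplet(arr, (3, 5, 5)))  # False: no 5 after the 3s
```

Checking every candidate this way is far too slow for N = 10^6; it only pins down what "valid" means.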

This was my first brute-force attempt (just to start off; it is O(n^3)):

n = int(input())
arr = list(map(int, input().split()))

ans = set()
for i in range(n):                      # position of a
    for j in range(i + 1, n):           # position of the first b
        if arr[i] != arr[j]:
            for k in range(j + 1, n):   # position of the second b
                if arr[j] == arr[k]:
                    ans.add((arr[i], arr[j], arr[k]))

print(len(ans))

Then I tried to optimize this to O(n^2), which would still be too slow, but I can't even get this version right:

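from collections import Counter

def solve():
    n = int(input())
    arr = list(map(int, input().split()))

    freq = Counter(arr)
    ans = set()
    for a in freq:
        if freq[a] < 1:  # never true: every key in freq occurs at least once
            continue
        for b in freq:
            # Only counts are checked, not positions, so (a, b, b) is added
            # even when every occurrence of a comes after the b's.
            if b != a and freq[b] >= 2:
                ans.add((a, b, b))

    return len(ans)


print(solve())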
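A minimal sketch of why this attempt gives wrong answers: the same counting logic, wrapped in a hypothetical function `counter_attempt` without `input()`, is compared with the brute force (wrapped as a hypothetical `brute_force`) on the array [3, 3, 5], where 5 only appears after both 3s:

```python
from collections import Counter


def counter_attempt(arr):
    # Same logic as solve(): pair every value a with every other value b
    # that occurs at least twice, ignoring where they occur.
    freq = Counter(arr)
    return sum(1 for a in freq for b in freq if b != a and freq[b] >= 2)


def brute_force(arr):
    # The O(n^3) brute force from the question, as a function.
    n = len(arr)
    ans = set()
    for i in range(n):
        for j in range(i + 1, n):
            if arr[i] != arr[j]:
                for k in range(j + 1, n):
                    if arr[j] == arr[k]:
                        ans.add((arr[i], arr[j], arr[k]))
    return len(ans)


arr = [3, 3, 5]
print(counter_attempt(arr), brute_force(arr))  # 1 0
```

The attempt counts (5, 3, 3) even though no 5 precedes the 3s, so any fix has to take positions into account.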

I can't figure out how to fix the logic of the O(n^2) version, let alone optimize it further to meet the given constraints. Any help would be much appreciated.
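One position-aware way to get a correct O(n^2) version, sketched here under the assumption that a dictionary of last occurrences is acceptable (the name `quadratic` is hypothetical): the innermost k-loop of the brute force only asks "does arr[j] occur again after j?", and that can be answered in O(1) by precomputing the last index of every value.

```python
def quadratic(arr):
    # last[v] = index of the final occurrence of v (later indices overwrite).
    last = {v: i for i, v in enumerate(arr)}
    ans = set()
    for i in range(len(arr)):
        for j in range(i + 1, len(arr)):
            # (arr[i], arr[j], arr[j]) is valid if arr[j] occurs again after j;
            # this lookup replaces the innermost k-loop of the brute force.
            if arr[i] != arr[j] and last[arr[j]] > j:
                ans.add((arr[i], arr[j]))
    return len(ans)


print(quadratic([5, 6, 7, 3, 3, 3]))  # 3
print(quadratic([3, 3, 5]))           # 0
```

This is still quadratic, so it remains too slow for N = 10^6, but it fixes the ordering bug of the Counter-based attempt.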


Solution

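  • At the second-to-last occurrence of each value b, add the number of distinct values other than b that appeared before it. Each such value a forms a valid triplet with this b and the last b, and the second-to-last occurrence is the latest position that still has another b after it, so it sees every possible a. This takes about 1.5 seconds for an array of length 10^6.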

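    from collections import Counter
    
    def linear(arr):
        ctr = Counter(arr)  # occurrences of each value not yet passed
        A = set()           # distinct values seen so far (candidates for a)
        result = 0
        for b in arr:
            # ctr[b] == 2 means this is the second-to-last occurrence of b.
            if ctr[b] == 2:
                # Every distinct earlier value except b itself can be a.
                result += len(A) - (b in A)
            ctr[b] -= 1
            A.add(b)
        return result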
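To see where the counts come from, here is an instrumented variant of the same loop (the name `linear_trace` is hypothetical) that records, for each second-to-last occurrence, its index, the value b, and how many values a it contributes. Because `ctr[b]` only decreases, it equals 2 exactly once for each value occurring at least twice, so every b is counted once.

```python
from collections import Counter


def linear_trace(arr):
    # Same loop as linear(), but records each contribution instead of summing.
    ctr = Counter(arr)
    seen = set()
    steps = []  # (index, b, number of distinct earlier values other than b)
    for i, b in enumerate(arr):
        if ctr[b] == 2:  # exactly two copies of b remain: this one and the last
            steps.append((i, b, len(seen) - (b in seen)))
        ctr[b] -= 1
        seen.add(b)
    return steps


print(linear_trace([5, 6, 7, 3, 3, 3]))  # [(4, 3, 3)]
print(linear_trace([1, 2, 1, 2]))        # [(0, 1, 0), (1, 2, 1)]
```

In the second example, the first 1 is already the second-to-last 1 and nothing precedes it, so it adds 0; the first 2 adds 1 for the triplet (1, 2, 2).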
    

    Testing with your small example and five larger random arrays, then timing an array of length 10^6:

    import random
    from time import time
    
    def original(arr):
        n = len(arr)
        ans = set()
        for i in range(n):
            for j in range(i + 1, n):
                if arr[i] != arr[j]:
                    for k in range(j + 1, n):
                        if arr[j] == arr[k]:
                            ans.add((arr[i], arr[j], arr[k]))
        return len(ans)
    
    def test(arr):
        expect = original(arr)
        result = linear(arr)
        print(result == expect, expect, result)
    
    # Correctness
    test([5, 6, 7, 3, 3, 3])
    for _ in range(5):
        test(random.choices(range(100), k=100))
    
    # Speed
    n = 10**6
    arr = random.choices(random.sample(range(10**9), n), k=n)
    t = time()
    print(linear(arr))
    print(time() - t)
    

    Sample output (Attempt This Online!):

    True 3 3
    True 732 732
    True 1038 1038
    True 629 629
    True 754 754
    True 782 782
    80414828386
    1.4968228340148926
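The random tests above rarely hit boundary cases. A small set of hand-checked edge cases can complement them; the snippet below repeats `linear()` so it runs on its own, and the expected values were worked out by hand from the problem definition rather than taken from the original post.

```python
from collections import Counter


def linear(arr):
    # Copy of linear() from the answer so this snippet is self-contained.
    ctr = Counter(arr)
    A = set()
    result = 0
    for b in arr:
        if ctr[b] == 2:
            result += len(A) - (b in A)
        ctr[b] -= 1
        A.add(b)
    return result


assert linear([]) == 0
assert linear([7]) == 0
assert linear([2, 2]) == 0        # no a before the b's
assert linear([1, 1, 1]) == 0     # a must differ from b
assert linear([1, 2, 2]) == 1     # (1, 2, 2)
assert linear([2, 2, 1]) == 0     # a must come before both b's
assert linear([1, 2, 1, 2]) == 1  # (1, 2, 2) only; no 1 after the second 1
print("all edge cases pass")
```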