
Finding number of distinct (unique) values in a sub-array for multiple queries


I have an array of up to 2×10^5 values, and I want to run a large number of queries on it. Each query has the form [L, R], and its result should be the number of unique values in the sub-array that starts at index L and ends at index R.
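For reference, a single query can be answered online with no preprocessing by scanning the sub-array into a hash set. That is far too slow for many queries over 2×10^5 elements, but it pins down the expected semantics (1-based, inclusive [L, R]) and is handy for checking a faster structure. A minimal sketch; the function name `count_distinct_naive` is hypothetical:

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Naive online baseline (hypothetical helper): number of distinct values in
// arr[L..R], where L and R are 1-based and inclusive.
// Runs in O(R - L + 1) expected time per query with no preprocessing.
int count_distinct_naive(const std::vector<int>& arr, int L, int R) {
    assert(1 <= L && L <= R && R <= static_cast<int>(arr.size()));
    std::unordered_set<int> seen;
    for (int i = L; i <= R; ++i)
        seen.insert(arr[i - 1]);  // 1-based position -> 0-based index
    return static_cast<int>(seen.size());
}
```

For example, with `arr = {1, 2, 3, 2, 4, 3, 1}` the query [3, 6] covers `{3, 2, 4, 3}` and yields 3.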

I know this can be done with Mo's algorithm in O(n√n) time. The catch is that Mo's algorithm is offline, whereas I need an online algorithm: in my case, the result of each query determines the next query.
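To make the offline restriction concrete, here is a minimal sketch of Mo's algorithm for this query (not code from the question). Queries are sorted by the block of L and then by R, so all of them must be known before the first answer is produced. Assumptions: every value is already compressed into [0, maxValue], queries are 1-based and inclusive, and the names `mo_distinct` and `maxValue` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Sketch of Mo's algorithm (OFFLINE): all queries must be known up front,
// because they are answered in a re-sorted order, not in arrival order.
// Assumes every value of arr lies in [0, maxValue]; queries are 1-based [L, R].
std::vector<int> mo_distinct(const std::vector<int>& arr,
                             const std::vector<std::pair<int, int>>& queries,
                             int maxValue) {
    const int n = static_cast<int>(arr.size());
    const int block = std::max(1, static_cast<int>(std::sqrt(n)));

    // Sort query ids by (block of L, then R).
    std::vector<int> order(queries.size());
    for (int i = 0; i < static_cast<int>(order.size()); ++i) order[i] = i;
    std::sort(order.begin(), order.end(), [&](int a, int b) {
        int ba = (queries[a].first - 1) / block;
        int bb = (queries[b].first - 1) / block;
        if (ba != bb) return ba < bb;
        return queries[a].second < queries[b].second;
    });

    std::vector<int> cnt(maxValue + 1, 0);   // occurrences in the current window
    std::vector<int> answer(queries.size());
    int distinct = 0;
    int curL = 0, curR = -1;                 // window arr[curL..curR], 0-based, empty
    auto add_at = [&](int i) { if (cnt[arr[i]]++ == 0) ++distinct; };
    auto remove_at = [&](int i) { if (--cnt[arr[i]] == 0) --distinct; };

    for (int q : order) {
        int L = queries[q].first - 1, R = queries[q].second - 1;  // to 0-based
        assert(0 <= L && L <= R && R < n);
        // Grow the window first, then shrink it, so it never becomes invalid.
        while (curR < R) add_at(++curR);
        while (curL > L) add_at(--curL);
        while (curR > R) remove_at(curR--);
        while (curL < L) remove_at(curL++);
        answer[q] = distinct;
    }
    return answer;
}
```

With q queries the window endpoints move O((n + q)√n) times in total, which gives the O(n√n) bound when q is on the order of n. The price is the sort: a query that depends on the previous answer cannot be fed into it.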

I tried building a segment tree whose nodes store all the distinct elements in their range. However, this turned out to be too slow for my purpose: the preprocessing alone takes too much time.


Solution

  • Here's my C++ attempt at a solution (also posted here) using a wavelet tree, implemented with code adapted from https://www.geeksforgeeks.org/wavelet-trees-introduction. The idea for reformulating the problem (from the link Photon posted in a comment) is to first build an array that stores, for each cell of the original array, the index of the next duplicate of that element to its right. A cell whose "next index" lies beyond the interval holds the last occurrence of its value inside the interval, so every distinct value in [L, R] contributes exactly one such cell. The problem therefore becomes counting how many cells in [L, R] have a next index greater than R, which a decorated wavelet tree answers in O(log n) per query. See the (1-based) query examples at the bottom of the code.

    // Adapted from https://www.geeksforgeeks.org/wavelet-trees-introduction
    
    #include <iostream>
    #include <vector>
    #include <map>
    #include <algorithm>
    #include <climits>
    using namespace std;
    
    // wavelet tree class 
    class wavelet_tree { 
    public: 
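        // Range of values [low, high] stored in this node
        int low, high;
    
        // Left and right children (stay null in leaf and empty nodes)
        wavelet_tree *l = nullptr, *r = nullptr;
    
        // freq[i] = how many of the first i elements of this node go to the
        // left child (in a leaf every element counts, so freq[i] == i)
        std::vector<int> freq;
    
        // Each node owns its children; copying is disabled so that every
        // node is deleted exactly once
        ~wavelet_tree() { delete l; delete r; }
        wavelet_tree(const wavelet_tree&) = delete;
        wavelet_tree& operator=(const wavelet_tree&) = delete;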
    
        // Constructor: builds the tree over the elements in [from, to),
        // whose values all lie in [x, y].
        // Note: std::stable_partition below reorders [from, to) in place.
        wavelet_tree(int* from, int* to, int x, int y) 
        { 
            // Initialising low and high 
            low = x, high = y; 
    
            // Array is of 0 length 
            if (from >= to) 
                return; 
    
            // Leaf: every value in this node is equal (homogeneous)
            // Example : 1 1 1 1 1
            if (high == low) { 
                // Assigning storage to freq array 
                freq.reserve(to - from + 1); 
    
                // Initialising the Freq array 
                freq.push_back(0); 
    
                // Assigning values 
                for (auto it = from; it != to; it++) 
    
                    // freq will be increasing as there'll 
                    // be no further sub-tree 
                    freq.push_back(freq.back() + 1); 
    
                return; 
            } 
    
            // Computing mid 
            int mid = (low + high) / 2; 
    
            // Lambda function to check if a number 
            // is less than or equal to mid 
            auto lessThanMid = [mid](int x) { 
                return x <= mid; 
            }; 
    
            // Assigning storage to freq array 
            freq.reserve(to - from + 1); 
    
            // Initialising the freq array 
            freq.push_back(0); 
    
            // Assigning value to freq array 
            for (auto it = from; it != to; it++) 
    
                // If lessThanMid returns 1(true), we add 
                // 1 to previous entry. Otherwise, we add 0 
                // (element goes to right sub-tree) 
                freq.push_back(freq.back() + lessThanMid(*it));      
    
            // std::stable_partition partitions the array w.r.t Mid 
            auto pivot = std::stable_partition(from, to, lessThanMid); 
    
            // Left sub-tree's object 
            l = new wavelet_tree(from, pivot, low, mid); 
    
            // Right sub-tree's object 
            r = new wavelet_tree(pivot, to, mid + 1, high); 
        } 
    
        // Count of values <= k among positions [l..r] (1-based)
        int kOrLess(int l, int r, int k)
        {
            // Empty range, or every value in this node is > k
            if (l > r or k < low)
                return 0;
    
            // Every value in this node is <= k
            if (high <= k)
                return r - l + 1;
    
            // Number of elements before position l, and up to position r,
            // that went to the left child
            int LtCount = freq[l - 1];
            int RtCount = freq[r];
    
            // Answer is (no. of elements <= k) in the
            // left child + (those <= k) in the right child
            return (this->l->kOrLess(LtCount + 1, RtCount, k) + 
                this->r->kOrLess(l - LtCount, r - RtCount, k)); 
        } 
    
        // Count of values >= k among positions [l..r] (1-based)
        int kOrMore(int l, int r, int k)
        {
            // Empty range, or every value in this node is < k
            if (l > r or k > high)
                return 0;
    
            // Every value in this node is >= k
            if (low >= k)
                return r - l + 1;
    
            // Number of elements before position l, and up to position r,
            // that went to the left child
            int LtCount = freq[l - 1];
            int RtCount = freq[r];
    
            // Answer is (no. of elements >= k) in the
            // left child + (those >= k) in the right child
            return (this->l->kOrMore(LtCount + 1, RtCount, k) + 
                this->r->kOrMore(l - LtCount, r - RtCount, k)); 
        }
    
    }; 
    
    
    int main() 
    { 
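        int arr[] = {1, 2, 3, 2, 4, 3, 1};
        const int size = 7;      // number of elements in arr
        int high = INT_MIN;      // largest value stored in next[]
    
        // next[i] = 1-based position of the next occurrence of arr[i] to its
        // right, or size + 1 if arr[i] does not occur again
        int next[size];          // legal ISO C++: size is a constant expression
        std::map<int, int> next_idx;  // value -> nearest position seen so far
    
        for (int i = size - 1; i >= 0; i--) {
            if (next_idx.find(arr[i]) == next_idx.end())
                next[i] = size + 1;
            else
                next[i] = next_idx[arr[i]];
            next_idx[arr[i]] = i + 1;
            high = max(high, next[i]);
        }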
    
        // Build the tree over next[] (values lie in [1, high]).
        // The constructor reorders next[] in place via std::stable_partition.
        wavelet_tree obj(next, next + size, 1, high);
    
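        // Queries are 1-based (NOT zero-based).
        // Distinct values in [L, R] = number of positions i in [L, R] with
        // next[i] > R, i.e. obj.kOrMore(L, R, R + 1).
        //
        //  1  2  3  4  5  6  7
        // {1, 2, 3, 2, 4, 3, 1};
        // query([3, 6]) = 3;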
        cout << obj.kOrMore(3, 6, 7) << '\n';
        // query([1, 4]) = 3;
        cout << obj.kOrMore(1, 4, 5) << '\n';
        // query([1, 7]) = 4;
        cout << obj.kOrMore(1, 7, 8) << '\n';
    
        return 0; 
    }
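The reduction the tree relies on can be checked independently of the wavelet tree with a short sketch. The helper names `build_next` and `count_by_next` are hypothetical, and an `std::unordered_map` stands in for the `std::map` used above. `count_by_next` scans the window linearly, so it only serves to verify the identity that `kOrMore(L, R, R + 1)` evaluates in logarithmic time:

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// nxt[i] = 1-based position of the next occurrence of arr[i - 1] to the right,
// or n + 1 if there is none (nxt[0] is unused, so nxt is indexed by position).
std::vector<int> build_next(const std::vector<int>& arr) {
    const int n = static_cast<int>(arr.size());
    std::vector<int> nxt(n + 1, 0);
    std::unordered_map<int, int> nearest;  // value -> nearest position to the right
    for (int i = n; i >= 1; --i) {
        auto it = nearest.find(arr[i - 1]);
        nxt[i] = (it == nearest.end()) ? n + 1 : it->second;
        nearest[arr[i - 1]] = i;
    }
    return nxt;
}

// A position i in [L, R] with nxt[i] > R holds the LAST occurrence of its value
// inside the window, so counting such positions counts each distinct value once.
// Linear scan for checking only; the wavelet tree gets the same number from
// kOrMore(L, R, R + 1).
int count_by_next(const std::vector<int>& nxt, int L, int R) {
    assert(1 <= L && L <= R && R < static_cast<int>(nxt.size()));
    int count = 0;
    for (int i = L; i <= R; ++i)
        if (nxt[i] > R) ++count;
    return count;
}
```

Comparing `count_by_next` with `obj.kOrMore(L, R, R + 1)` over random arrays and random windows is a cheap way to test the tree before relying on it.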