
Update in Segment Tree


I am learning segment trees, and I came across this question. There is an array A and two types of operations:

1. Find the sum in the range L to R.
2. Update the elements in the range L to R by the value X.

The update should work like this:

A[L] = 1*X;
A[L+1] = 2*X;
A[L+2] = 3*X;
...
A[R] = (R-L+1)*X;

How should I handle the second type of query? Can anyone please suggest an algorithm to do this with a segment tree, or is there a better solution?


Solution

  • So, we need to efficiently add the corresponding values of the arithmetic progression X, 2X, ..., (R-L+1)X (with step X) to the elements of the interval [L,R], and to efficiently find the sums over arbitrary intervals.
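    In formulas: an update (L, R, X) adds a_i = (i - L + 1) X to every element i in [L, R]. For a Segment Tree node covering [left, right] that lies completely inside [L, R], the added values form an arithmetic progression, so their effect on the node's sum depends only on its first and last items:

```latex
\begin{aligned}
a_i &= (i - L + 1)\,X, \qquad L \le i \le R \\
\text{first} &= (\text{left} - L + 1)\,X, \qquad \text{last} = (\text{right} - L + 1)\,X \\
\sum_{i=\text{left}}^{\text{right}} a_i &= \frac{(\text{first} + \text{last})\,(\text{right} - \text{left} + 1)}{2} \\
\sum_{i=L}^{R} a_i &= \frac{(R - L + 1)(R - L + 2)}{2}\,X
\end{aligned}
```

    These are the same quantities that the implementation computes with first_item_offset, last_item_offset and the sum update in propagate().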

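    To solve this problem efficiently, let's use a Segment Tree with Lazy Propagation.

    The basic ideas are the following:

    1. An update adds an arithmetic progression to the elements of an interval, and the sum of two arithmetic progressions is again an arithmetic progression. So all pending (lazy) updates of a node can be stored as just two numbers: the first and the last items of the accumulated progression.
    2. The sum of an arithmetic progression with n items is (first + last) * n / 2, so a node can apply its pending progression to its own sum in O(1).
    3. A progression over the node's interval can be split in O(1) into the two parts that cover the left and the right child, so the pending progression can be pushed down to the children.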
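    The fact that pending progressions on the same node can be merged by simply adding their first and last items can be checked with a small standalone program. This is only an illustration: the class name ApCombineDemo and the chosen values are hypothetical. Two updates with X = 2 and X = 3 are applied to the same five items (the second one as if its range started two positions to the left of the node), and every merged item is compared with the progression rebuilt from the first item and the step.

```java
// Hypothetical demo class: checks that two pending arithmetic progressions
// on the same five items can be merged by adding them item by item,
// and that the result is again a progression described by (first, last).
public class ApCombineDemo {
    // Sum of an arithmetic progression given its first item, last item and item count
    static long apSum(long first, long last, long count) {
        return (first + last) * count / 2;
    }

    public static void main(String[] args) {
        int n = 5;
        long[] combined = new long[n];
        for (int i = 0; i < n; i++) {
            combined[i] += 2L * (i + 1); // update with X = 2, offsets 1..5: 2, 4, 6, 8, 10
            combined[i] += 3L * (i + 3); // update with X = 3, offsets 3..7: 9, 12, 15, 18, 21
        }
        // Only the first and last items would be stored lazily in the node
        long first = combined[0];
        long last = combined[n - 1];
        long step = (last - first) / (n - 1);

        long direct = 0;
        for (int i = 0; i < n; i++) {
            // Every merged item is reproduced from first and step alone
            if (combined[i] != first + step * i) {
                throw new AssertionError("merged values are not an arithmetic progression");
            }
            direct += combined[i];
        }
        System.out.println(first + " " + last + " " + step);     // 11 31 5
        System.out.println(direct + " " + apSum(first, last, n)); // 105 105
    }
}
```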

    So, a node of the Segment Tree for the given problem will have the following structure:

    class Node {
        int left; // Left boundary of the current SegmentTree node
        int right; // Right boundary of the current SegmentTree node
    
        int sum; // Sum on the interval [left,right]
    
        int first; // First item of arithmetic progression inside given node
        int last; // Last item of arithmetic progression
    
        Node left_child;
        Node right_child;
    
        // Constructor
        Node(int[] arr, int l, int r) { ... }
    
        // Add arithmetic progression with step X on the interval [l,r]
        // O(log(N))
        void add(int l, int r, int X) { ... }
    
        // Request the sum on the interval [l,r]
        // O(log(N))
        int query(int l, int r) { ... }
    
        // Lazy Propagation
        // O(1)
        void propagate() { ... }
    }
    

    A characteristic feature of a Segment Tree with Lazy Propagation is that every time a node of the tree is traversed, the Lazy Propagation routine (which runs in O(1)) is executed for that node. The illustration below shows the Lazy Propagation logic for an arbitrary node that has children:

    [Figure: Lazy Propagation for an arbitrary node with children]

    During the Lazy Propagation, the first and last items of the arithmetic progressions of the child nodes are updated, and the sum stored in the parent node is updated as well.
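    The split performed during propagation can be tried in isolation. The sketch below (hypothetical class ApSplitDemo) uses the same formulas as the propagate() method: the step is (last - first) / (items_count - 1), and the left child receives the items with offsets 0..mid, where mid = (items_count - 1) / 2, which matches the midpoint m = (l + r) / 2 used when the tree is built.

```java
import java.util.Arrays;

// Hypothetical demo class: splits a node's pending arithmetic progression
// between its two children, mirroring the formulas in propagate().
public class ApSplitDemo {
    // Returns {leftFirst, leftLast, rightFirst, rightLast}; requires itemsCount > 1
    static long[] split(long first, long last, int itemsCount) {
        long step = (last - first) / (itemsCount - 1);
        int mid = (itemsCount - 1) / 2; // offset of the last item in the left child
        return new long[] {
            first, first + step * mid,      // left child: offsets 0..mid
            first + step * (mid + 1), last  // right child: offsets mid+1..itemsCount-1
        };
    }

    // Sum of an arithmetic progression given its first item, last item and item count
    static long apSum(long first, long last, long count) {
        return (first + last) * count / 2;
    }

    public static void main(String[] args) {
        // Node covering 8 items with the pending progression 3, 6, ..., 24 (step 3)
        long[] c = split(3, 24, 8);
        System.out.println(Arrays.toString(c)); // [3, 12, 15, 24]
        // Both children cover 4 items; their sums add up to the parent's sum (108)
        System.out.println(apSum(c[0], c[1], 4) + apSum(c[2], c[3], 4) == apSum(3, 24, 8)); // true
    }
}
```

    With an odd number of items the left child gets the extra item: split(2, 10, 5) gives {2, 6, 8, 10}, that is 2, 4, 6 on the left and 8, 10 on the right.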

    Implementation

    Below is the Java implementation of the described approach (with additional comments):

    class Node {
        int left; // Left boundary of the current SegmentTree node
        int right; // Right boundary of the current SegmentTree node
        int sum; // Sum on the interval
        int first; // First item of arithmetic progression
        int last; // Last item of arithmetic progression
        Node left_child;
        Node right_child;
    
        /**
         * Construction of a Segment Tree
         * which spans over the interval [l,r]
         */
        Node(int[] arr, int l, int r) {
            left = l;
            right = r;
            if (l == r) { // Leaf
                sum = arr[l];
            } else { // Construct children
                int m = (l + r) / 2;
                left_child = new Node(arr, l, m);
                right_child = new Node(arr, m + 1, r);
                // Update accumulated sum
                sum = left_child.sum + right_child.sum;
            }
        }
    
        /**
         * Lazily adds the values of the arithmetic progression
         * with step X on the interval [l, r]
         * O(log(N))
         */
        void add(int l, int r, int X) {
            // Lazy propagation
            propagate();
            if ((r < left) || (right < l)) {
                // If updated interval doesn't overlap with current subtree
                return;
            } else if ((l <= left) && (right <= r)) {
                // If updated interval fully covers the current subtree
                // Update the first and last items of the arithmetic progression
                int first_item_offset = (left - l) + 1;
                int last_item_offset = (right - l) + 1;
                first = X * first_item_offset;
                last = X * last_item_offset;
                // Lazy propagation
                propagate();
            } else {
                // If updated interval partially overlaps with current subtree
                left_child.add(l, r, X);
                right_child.add(l, r, X);
                // Update accumulated sum
                sum = left_child.sum + right_child.sum;
            }
        }
    
        /**
         * Returns the sum on the interval [l, r]
         * O(log(N))
         */
        int query(int l, int r) {
            // Lazy propagation
            propagate();
            if ((r < left) || (right < l)) {
                // If requested interval doesn't overlap with current subtree
                return 0;
            } else if ((l <= left) && (right <= r)) {
                // If requested interval fully covers the current subtree
                return sum;
            } else {
                // If requested interval partially overlaps with current subtree
                return left_child.query(l, r) + right_child.query(l, r);
            }
        }
    
        /**
         * Lazy propagation
         * O(1)
         */
        void propagate() {
            // Update the accumulated value
            // with the sum of Arithmetic Progression
            int items_count = (right - left) + 1;
            sum += ((first + last) * items_count) / 2;
            if (right != left) { // Current node is not a leaf
                // Calculate the step of the Arithmetic Progression of the current node
                int step = (last - first) / (items_count - 1);
                // Update the first and last items of the arithmetic progression
                // inside the left and right subtrees
                // Distribute the arithmetic progression between child nodes
                // [a(1) to a(N)] -> [a(1) to a(N/2)] and [a(N/2+1) to a(N)]
                int mid = (items_count - 1) / 2;
                left_child.first += first;
                left_child.last += first + (step * mid);
                right_child.first += first + (step * (mid + 1));
                right_child.last += last;
            }
            // Reset the arithmetic progression of the current node
            first = 0;
            last = 0;
        }
    }
    

    The Segment Tree in the provided solution is implemented explicitly, using objects and references; however, it can easily be modified to use arrays instead.
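    A minimal sketch of such an array-based variant follows, assuming a 1-based heap layout (the children of node i are 2i and 2i+1, so arrays of size 4N suffice) and a non-empty input array. The class name ArraySegmentTree is hypothetical. The propagate/add/query logic is the same as in the Node class; only the storage changes, and long is used for sums and pending items to postpone integer overflow.

```java
// Hypothetical array-based variant of the Node class; assumes arr.length >= 1.
public class ArraySegmentTree {
    private final int n;
    // Per node: accumulated sum and pending progression (first and last items)
    private final long[] sum, first, last;

    public ArraySegmentTree(int[] arr) {
        n = arr.length;
        sum = new long[4 * n];
        first = new long[4 * n];
        last = new long[4 * n];
        build(1, 0, n - 1, arr);
    }

    private void build(int node, int lo, int hi, int[] arr) {
        if (lo == hi) { sum[node] = arr[lo]; return; }
        int m = (lo + hi) / 2;
        build(2 * node, lo, m, arr);
        build(2 * node + 1, m + 1, hi, arr);
        sum[node] = sum[2 * node] + sum[2 * node + 1];
    }

    // Applies the pending progression to the node's sum and pushes it to the children
    private void propagate(int node, int lo, int hi) {
        long count = hi - lo + 1;
        sum[node] += (first[node] + last[node]) * count / 2;
        if (lo != hi) {
            long step = (last[node] - first[node]) / (count - 1);
            long mid = (count - 1) / 2;
            first[2 * node] += first[node];
            last[2 * node] += first[node] + step * mid;
            first[2 * node + 1] += first[node] + step * (mid + 1);
            last[2 * node + 1] += last[node];
        }
        first[node] = 0;
        last[node] = 0;
    }

    // Adds X, 2X, ..., (r-l+1)X to the elements l..r
    public void add(int l, int r, long x) { add(1, 0, n - 1, l, r, x); }

    private void add(int node, int lo, int hi, int l, int r, long x) {
        propagate(node, lo, hi);
        if (r < lo || hi < l) return;          // no overlap
        if (l <= lo && hi <= r) {              // full cover: store progression lazily
            first[node] = x * (lo - l + 1);
            last[node] = x * (hi - l + 1);
            propagate(node, lo, hi);
            return;
        }
        int m = (lo + hi) / 2;                 // partial overlap
        add(2 * node, lo, m, l, r, x);
        add(2 * node + 1, m + 1, hi, l, r, x);
        sum[node] = sum[2 * node] + sum[2 * node + 1];
    }

    // Returns the sum of the elements l..r
    public long query(int l, int r) { return query(1, 0, n - 1, l, r); }

    private long query(int node, int lo, int hi, int l, int r) {
        propagate(node, lo, hi);
        if (r < lo || hi < l) return 0;
        if (l <= lo && hi <= r) return sum[node];
        int m = (lo + hi) / 2;
        return query(2 * node, lo, m, l, r) + query(2 * node + 1, m + 1, hi, l, r);
    }
}
```

    For example, starting from an all-zero array of length 5, add(1, 3, 2) turns it into [0, 2, 4, 6, 0], so query(0, 4) returns 12 and query(2, 3) returns 10.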

    Testing

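    Below are the randomized tests, which compare two implementations: sequential processing of each query in O(N), and the Segment Tree, which handles each query in O(log(N)).

    The Java implementation of the randomized tests: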

    import java.util.Random;
    
    // Test driver (class name chosen for illustration);
    // the Node class above must be in the same file or package
    public class SegmentTreeTest {
        public static void main(String[] args) {
            // Initialize the random generator with a predefined seed
            // to make the test reproducible
            Random rnd = new Random(1);
    
            int test_cases_num = 20;
            int max_arr_size = 100;
            int num_queries = 50;
            int max_progression_step = 20;
    
            for (int test = 0; test < test_cases_num; test++) {
                // Create an array of random length
                int[] arr = new int[rnd.nextInt(max_arr_size) + 1];
                Node segmentTree = new Node(arr, 0, arr.length - 1);
    
                for (int query = 0; query < num_queries; query++) {
                    if (rnd.nextDouble() < 0.5) {
                        // Update on interval [l,r]
                        int l = rnd.nextInt(arr.length);
                        int r = rnd.nextInt(arr.length - l) + l;
                        int X = rnd.nextInt(max_progression_step);
                        update_sequential(arr, l, r, X); // O(N)
                        segmentTree.add(l, r, X); // O(log(N))
                    } else {
                        // Request sum on interval [l,r]
                        int l = rnd.nextInt(arr.length);
                        int r = rnd.nextInt(arr.length - l) + l;
                        int expected = query_sequential(arr, l, r); // O(N)
                        int actual = segmentTree.query(l, r); // O(log(N))
                        if (expected != actual) {
                            throw new RuntimeException("Results are different!");
                        }
                    }
                }
            }
            System.out.println("All results are equal!");
        }
    
        // Reference update: adds X, 2X, ..., (right-left+1)X element by element
        static void update_sequential(int[] arr, int left, int right, int X) {
            for (int i = left; i <= right; i++) {
                arr[i] += X * ((i - left) + 1);
            }
        }
    
        // Reference query: sums the interval element by element
        static int query_sequential(int[] arr, int left, int right) {
            int sum = 0;
            for (int i = left; i <= right; i++) {
                sum += arr[i];
            }
            return sum;
        }
    }