python algorithm data-structures segment-tree

Does a recursive segment tree require more space than an iterative one?


I'm learning the segment tree data structure.

I have seen several iterative segment trees that use only 2n space. So I tried to use the same build method in a segment tree with recursive update and sumRange. Is this not allowed? Why can an iterative seg tree be stored in 2n but a recursive one needs 4n? Or do I just have an implementation flaw in my non-working tree?

For my 2n tree, I'm using a 1-indexed tree, so nothing is stored at tree[0]. This means the root is at tree[1]. I make recursive calls using the initial range 1 to n - 1, which I'm not sure about. I get different wrong answers when I make it go to self.n or start at 0. I also get different wrong answers if I pass in index + 1, left + 1 or right + 1.

Here is my implementation:

class NumArray:
    # Classic Segment Tree
    #
    # Range-sum structure kept in a flat list of 2 * n slots. build() fills the
    # list bottom-up (leaves at tree[n .. 2n - 1], root at tree[1], tree[0]
    # unused); update() and sumRange() descend recursively from the root.
    #
    # NOTE(review): the two halves disagree about the layout. The recursive
    # helpers assume node k covers [seg_left, seg_right] and that children
    # 2k / 2k + 1 cover the halves split at mid, but build() never looks at
    # ranges. For n == 3, build() produces tree[2] = nums[1] + nums[2] and
    # tree[3] = nums[0], whereas a midpoint split of three elements expects
    # tree[2] to hold the first two elements and tree[3] the last one.

    def __init__(self, nums: List[int]):
        """Allocate 2 * len(nums) slots and build the tree from nums."""
        self.n = len(nums)
        self.tree = [0] * self.n * 2
        self.build(nums)

    def build(self, nums):
        """Fill the tree bottom-up: leaves first, then parents down to index 1."""
        # leaves: nums[i] is stored at tree[n + i]
        for i in range(self.n):
            self.tree[i + self.n] = nums[i]

        # internal: each parent i is the sum of children 2i and 2i + 1,
        # computed from index n - 1 down to the root at index 1
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[i * 2] + self.tree[i * 2 + 1]

    def merge(self, left, right):
        """Combine the values of two child nodes (here: their sum)."""
        return left + right

    def _update(self, tree_idx, seg_left, seg_right, i, val):
        """Recursively set position i to val below node tree_idx, then re-merge.

        tree_idx is the node being visited and [seg_left, seg_right] is the
        range this recursion assumes that node covers.
        """
        # leaf: a width-one segment is overwritten in place.
        # NOTE(review): the slot reached here need not be the one build() used
        # for this element; e.g. with n == 3 and the root range [1, 2] passed
        # by update(), i == 1 ends at tree[2], which build() filled as an
        # internal node.
        if seg_left == seg_right:
            self.tree[tree_idx] = val
            return

        # descend into the half that contains i
        mid = (seg_left + seg_right) // 2
        if i > mid:
            self._update(tree_idx * 2 + 1, mid + 1, seg_right, i, val)
        else:
            self._update(tree_idx * 2, seg_left, mid, i, val)

        # recompute this node from its two children on the way back up
        self.tree[tree_idx] = self.merge(self.tree[tree_idx * 2], self.tree[tree_idx * 2 + 1])

    def update(self, index: int, val: int) -> None:
        """Set the element at index to val.

        NOTE(review): the root range passed here is [1, n - 1], which spans
        only n - 1 positions although build() stored n leaves.
        """
        self._update(1, 1, self.n - 1, index, val)

    def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
        """Return the part of the query sum contributed by node tree_idx.

        [seg_left, seg_right] is the range this recursion assumes the node
        covers; [query_left, query_right] is the requested range.
        """
        # segment out of query bounds
        if seg_left > query_right or seg_right < query_left:
            return 0

        # segment fully in bounds
        if seg_left >= query_left and seg_right <= query_right:
            return self.tree[tree_idx]

        # segment partially in bounds
        mid = (seg_left + seg_right) // 2

        # this is not necessary for correctness, but helps with efficiency (we only go down 1 path if 2 is unnecessary)
        if query_left > mid:
            return self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)
        elif query_right <= mid:
            return self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)

        # query straddles mid: combine the contributions of both halves
        left_sum = self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)
        right_sum = self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)

        return self.merge(left_sum, right_sum)

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum over [left, right].

        NOTE(review): uses the same [1, n - 1] root range as update().
        """
        return self._sumRange(1, 1, self.n - 1, left, right)

I do have a working fully recursive 0-indexed version, but it uses double the space

class NumArray:
    """Range-sum segment tree, fully recursive, stored heap-style from index 0.

    The node at index k has its children at 2k + 1 and 2k + 2. Because the
    recursive midpoint split does not fill the heap densely, 4 * n slots are
    allocated so that every reachable index is valid.
    """

    def __init__(self, nums: List[int]):
        """Build the tree over nums; an empty list yields an empty tree."""
        self.n = len(nums)
        self.tree = [0] * self.n * 4
        # With no elements there is no root segment, and recursing on the
        # range [0, -1] would never reach the left == right base case.
        if self.n:
            self.build(nums, 0, 0, self.n - 1)

    def build(self, nums, tree_idx, left, right):
        """Fill the subtree rooted at tree_idx, which covers nums[left..right]."""
        if left == right:
            # leaf: holds the element itself
            self.tree[tree_idx] = nums[left]
            return

        mid = (left + right) // 2
        left_child = tree_idx * 2 + 1
        right_child = tree_idx * 2 + 2
        self.build(nums, left_child, left, mid)
        self.build(nums, right_child, mid + 1, right)

        # Use merge() here too, so that changing the combining operation in
        # one place keeps build and update consistent.
        self.tree[tree_idx] = self.merge(self.tree[left_child], self.tree[right_child])

    def merge(self, left, right):
        """Combine the values of two child nodes (here: their sum)."""
        return left + right

    def _update(self, tree_idx, seg_left, seg_right, i, val):
        """Set position i to val inside the subtree at tree_idx and re-merge.

        The subtree at tree_idx covers [seg_left, seg_right]; i must lie
        inside that range.
        """
        if seg_left == seg_right:
            # leaf: overwrite the stored element
            self.tree[tree_idx] = val
            return

        mid = (seg_left + seg_right) // 2
        left_child = tree_idx * 2 + 1
        right_child = tree_idx * 2 + 2
        if i > mid:
            self._update(right_child, mid + 1, seg_right, i, val)
        else:
            self._update(left_child, seg_left, mid, i, val)

        # recompute this node from its children on the way back up
        self.tree[tree_idx] = self.merge(self.tree[left_child], self.tree[right_child])

    def update(self, index: int, val: int) -> None:
        """Set the element at index to val.

        Raises:
            IndexError: if index is outside [0, n - 1]. Without this check the
                descent would silently overwrite the first or last leaf.
        """
        if not 0 <= index < self.n:
            raise IndexError("index out of range")
        self._update(0, 0, self.n - 1, index, val)

    def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
        """Return the sum of the overlap of the node's segment and the query.

        The node at tree_idx covers [seg_left, seg_right]; the requested range
        is [query_left, query_right].
        """
        # segment entirely outside the query: contributes nothing
        if seg_left > query_right or seg_right < query_left:
            return 0

        # segment entirely inside the query: its stored sum is the answer
        if seg_left >= query_left and seg_right <= query_right:
            return self.tree[tree_idx]

        # partial overlap: split at the midpoint
        mid = (seg_left + seg_right) // 2
        left_child = tree_idx * 2 + 1
        right_child = tree_idx * 2 + 2

        # Not needed for correctness, but avoids visiting a child that the
        # query cannot touch.
        if query_left > mid:
            return self._sumRange(right_child, mid + 1, seg_right, query_left, query_right)
        if query_right <= mid:
            return self._sumRange(left_child, seg_left, mid, query_left, query_right)

        left_sum = self._sumRange(left_child, seg_left, mid, query_left, query_right)
        right_sum = self._sumRange(right_child, mid + 1, seg_right, query_left, query_right)
        return self.merge(left_sum, right_sum)

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum of the elements in [left, right] (inclusive).

        Parts of the range that fall outside the array contribute 0; an empty
        tree always returns 0.
        """
        if not self.n:
            return 0
        return self._sumRange(0, 0, self.n - 1, left, right)

This website verifies if a segment tree implementation is correct

Also I know recursive uses more call stack space. That's not what my question is about


Solution

  • In fact, it is easy to modify the standard recursive segment tree to use 2N instead of 4N nodes. It is common to see an array of 4N nodes allocated because it is simple and doesn't require much thought, but there can be a lot of wasted space that isn't used.

    Note that a segment tree is a full binary tree with N leaves, so there must be N - 1 internal nodes (this is easily seen by induction). Thus, we really only require 2N - 1 nodes and we can optimize the memory usage by changing the order in which we store the nodes.

    After any node, we can store its entire left subtree, followed by the entire right subtree in consecutive positions in the array representing the segment tree. For a node at index i, its left child directly follows it at index i + 1. Its right child directly follows the last element of the left subtree. The left subtree is also a full binary tree and it represents the range [left, mid], so there are mid - left + 1 leaf nodes. Thus, in total there are 2 * (mid - left + 1) - 1 nodes in the left subtree. We add 1 to move to the right child, so the right child is at index i + 2 * (mid - left + 1). With this in mind, the rest of the segment tree operations stay the same.

    class NumArray:
        """Recursive range-sum segment tree stored in 2 * n slots.

        Nodes are stored in pre-order: a node at index i is followed
        immediately by its whole left subtree and then by its whole right
        subtree. A subtree over k leaves occupies 2 * k - 1 slots, so for a
        node covering [left, right] with midpoint mid, the left child is at
        i + 1 and the right child is at i + 2 * (mid - left + 1).
        """

        def __init__(self, nums: List[int]):
            """Build the tree over nums; an empty list yields an empty tree."""
            self.n = len(nums)
            self.tree = [0] * 2 * self.n
            # With no elements there is no root segment, and recursing on the
            # range [0, -1] would never reach the left == right base case.
            if self.n:
                self.build(nums, 0, 0, self.n - 1)

        def merge(self, left_val, right_val):
            """Combine the values of two child nodes (here: their sum)."""
            return left_val + right_val

        def build(self, nums, tree_idx, left, right):
            """Fill the subtree at tree_idx covering nums[left..right].

            Returns the value stored at tree_idx.
            """
            if left == right:
                self.tree[tree_idx] = nums[left]
            else:
                mid = (left + right) // 2
                left_child = tree_idx + 1
                # skip the 2 * (mid - left + 1) - 1 nodes of the left subtree
                right_child = tree_idx + 2 * (mid - left + 1)
                left_val = self.build(nums, left_child, left, mid)
                right_val = self.build(nums, right_child, mid + 1, right)
                self.tree[tree_idx] = self.merge(left_val, right_val)
            return self.tree[tree_idx]

        def _update(self, tree_idx, left, right, i, val):
            """Set position i to val inside the subtree at tree_idx and re-merge.

            The subtree at tree_idx covers [left, right]; i must lie inside
            that range.
            """
            if left == right:
                self.tree[tree_idx] = val
                return

            mid = (left + right) // 2
            left_child = tree_idx + 1
            right_child = tree_idx + 2 * (mid - left + 1)
            if i > mid:
                self._update(right_child, mid + 1, right, i, val)
            else:
                self._update(left_child, left, mid, i, val)
            # recompute this node from its children on the way back up
            self.tree[tree_idx] = self.merge(self.tree[left_child], self.tree[right_child])

        def update(self, index: int, val: int) -> None:
            """Set the element at index to val.

            Raises:
                IndexError: if index is outside [0, n - 1]. Without this check
                    the descent would silently overwrite a boundary leaf.
            """
            if not 0 <= index < self.n:
                raise IndexError("index out of range")
            self._update(0, 0, self.n - 1, index, val)

        def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
            """Return the sum over [query_left, query_right] inside this subtree.

            The subtree at tree_idx covers [seg_left, seg_right]. The query
            must be empty (query_left > query_right) or lie entirely within
            the segment; the recursion narrows it to each child's range.
            """
            if query_left > query_right:
                return 0
            if seg_left == query_left and seg_right == query_right:
                return self.tree[tree_idx]

            mid = (seg_left + seg_right) // 2
            left_sum = self._sumRange(
                tree_idx + 1, seg_left, mid,
                query_left, min(mid, query_right))
            right_sum = self._sumRange(
                tree_idx + 2 * (mid - seg_left + 1), mid + 1, seg_right,
                max(mid + 1, query_left), query_right)
            return self.merge(left_sum, right_sum)

        def sumRange(self, left: int, right: int) -> int:
            """Return the sum of the elements in [left, right] (inclusive).

            Parts of the range that fall outside the array contribute 0.
            """
            # _sumRange requires the query to lie within the root segment: a
            # query that sticks out never matches a segment exactly and would
            # recurse without bound. Clamping also covers the empty tree,
            # where the clamped range is always empty.
            left = max(left, 0)
            right = min(right, self.n - 1)
            return self._sumRange(0, 0, self.n - 1, left, right)