python time-complexity nested-lists space-complexity

Space and time complexity of flattening a nested list of arbitrary depth


Given a Python list that contains nested lists of arbitrary depth, the goal is to return a completely flattened list. For example, for the sample input [1, [2], [[[3]]], 1], the output should be [1, 2, 3, 1].

My solution:

def flatten(lst):
  stack = [[lst, 0]]
  result = []

  while stack:
    current_lst, start_index = stack[-1]
    for i in range(start_index, len(current_lst)):
      if isinstance(current_lst[i], list):
        # Update the start_index of current list
        # to the next element after the nested list
        stack[-1][1] = i + 1
        # Add nested list to stack
        # and initialize its start_index to 0
        stack.append([current_lst[i], 0])
        # Pause current_lst traversal
        break
      # non list item
      # add item to result
      result.append(current_lst[i])
    else:
      # no nested list
      # remove current list from stack
      stack.pop()

  return result

Assuming it is correct, I would like to know the time complexity and the auxiliary space complexity of my solution.

What I think

Time Complexity:

I believe the solution has a time complexity of O(m + n), where m is the total number of nested lists at all levels and n is the total number of atomic (non-list) elements at all levels.

Space Complexity:

I believe the space complexity is O(d), where d is the maximum nesting depth. This is because the stack tracks the current state of the traversal, and its size is proportional to the nesting depth.

Is the solution correct?

Are the time and space analyses correct?


Solution

  • Yes, the solution is correct.

    Yes, the time and space analyses are correct, provided you don't count the space used by result as auxiliary space, which is reasonable. Note, however, that result over-allocates and reallocates as it grows, which you could regard as taking O(n) auxiliary space. You could avoid that by making two passes over the whole input: one to count the atomic elements, so that you can create result = [None] * n, and a second to fill it.
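
One way to sanity-check these bounds empirically is to add counters to the original algorithm. The sketch below is only an illustration: the function name flatten_instrumented and the counters steps and max_depth are hypothetical additions, not part of the question's code. It counts how many elements the inner loop visits and the peak stack size.

```python
def flatten_instrumented(lst):
  """The question's algorithm plus two measurement counters (hypothetical)."""
  stack = [[lst, 0]]
  result = []
  steps = 0      # inner-loop iterations: one per element visited
  max_depth = 1  # largest number of entries ever held on the stack

  while stack:
    current_lst, start_index = stack[-1]
    for i in range(start_index, len(current_lst)):
      steps += 1
      if isinstance(current_lst[i], list):
        # Remember where to resume, then descend into the nested list
        stack[-1][1] = i + 1
        stack.append([current_lst[i], 0])
        max_depth = max(max_depth, len(stack))
        break
      result.append(current_lst[i])
    else:
      # Current list is exhausted
      stack.pop()

  return result, steps, max_depth


flat, steps, max_depth = flatten_instrumented([1, [2], [[[3]]], 1])
print(flat, steps, max_depth)  # [1, 2, 3, 1] 8 4
```

For the sample input there are n = 4 atoms and m = 4 nested lists, and every element (atom or nested list) is visited exactly once, so steps is m + n = 8. The peak stack size of 4 equals the nesting depth if the outermost list is counted as depth 1.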

    By the way, it's better to use iterators instead of your own list+index pairs:

    def flatten(lst):
      stack = [iter(lst)]
      result = []
      while stack:
        for item in stack[-1]:
          if isinstance(item, list):
            stack.append(iter(item))
            break
          result.append(item)
        else:
          stack.pop()
      return result
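
The iterator version works because a list iterator remembers its own position, so breaking out of a for loop and later looping over the same iterator resumes right after the last item consumed. A small sketch of that behaviour, with purely illustrative variable names (items, it, seen, rest):

```python
items = [10, [20], 30]
it = iter(items)

seen = []
for x in it:
  if isinstance(x, list):
    break  # pause at the nested list; the iterator has already moved past it
  seen.append(x)

# Consuming the same iterator again continues after the nested list
rest = list(it)

print(seen)  # [10]
print(rest)  # [30]
```

This is what replaces the explicit stack[-1][1] = i + 1 bookkeeping of the list+index approach.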
    

    Or with the mentioned space optimization:

    def flatten(lst):
      def atoms():
        stack = [iter(lst)]
        while stack:
          for item in stack[-1]:
            if isinstance(item, list):
              stack.append(iter(item))
              break
            yield item
          else:
            stack.pop()
      n = sum(1 for _ in atoms())
      result = [None] * n
      i = 0
      for result[i] in atoms():
        i += 1
      return result
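
The loop for result[i] in atoms() uses a subscript as the loop target, which is valid Python but can be surprising to read. Below is a sketch of an equivalent, more explicit formulation using enumerate. The names flatten_two_pass and the standalone atoms(lst) helper are hypothetical restructurings for illustration; the traversal is the same as in the generator above.

```python
def atoms(lst):
  """Yield the non-list items of lst in order, using an explicit stack."""
  stack = [iter(lst)]
  while stack:
    for item in stack[-1]:
      if isinstance(item, list):
        stack.append(iter(item))
        break
      yield item
    else:
      stack.pop()


def flatten_two_pass(lst):
  # Pass 1: count the atoms without storing them
  n = sum(1 for _ in atoms(lst))
  # Allocate the result once, at its final length
  result = [None] * n
  # Pass 2: fill each slot in order
  for i, item in enumerate(atoms(lst)):
    result[i] = item
  return result


print(flatten_two_pass([1, [2], [[[3]]], 1]))  # [1, 2, 3, 1]
```

The trade-off is that the input is traversed twice, so the running time stays O(m + n) but with roughly double the constant factor.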