Getting a flat view of a nested list


In Python, is it possible to get a flat view of a list of lists that dynamically adapts to changes to the original, nested list?

To be clear, I am not looking for a static snapshot, but for a view that reflects changes.

Further, the sub-lists should not be restricted to primitive types but be able to hold arbitrary objects, and they should not be tied to a fixed size but be free to shrink or grow.

Simple example:

a = ["a", "b", "c"]
b = ["d", "e", "f"]
view = flat_view([a, b])
# `view` should show ["a", "b", "c", "d", "e", "f"]
b[0] = "x"
# `view` should show ["a", "b", "c", "x", "e", "f"]

The implementation of flat_view() is what I'm looking for.


Solution

  • You would need to create a class that holds references to the original lists.

    You don't want a copy of the lists; a reference to each one is enough. The class then knows how to look up the values in each of the lists it holds.

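    Because the class keeps a running total of the sub-list lengths, any item can be located with a binary search in O(log k) time, where k is the number of sub-lists, without storing a flattened copy of the data. The only extra memory is the k + 1 cumulative lengths, which are rebuilt in O(k) time whenever a wrapped list changes.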
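    The index arithmetic behind this can be sketched on its own with the standard library: itertools.accumulate builds the cumulative lengths and bisect.bisect_right performs the binary search. This is a minimal illustration, not the answer's code; the helper names offsets_of and locate are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Any, List, Tuple


def offsets_of(lists: List[List[Any]]) -> List[int]:
    # Cumulative start offsets, e.g. lengths 3, 0, 2 -> [0, 3, 3, 5].
    # The last entry is the total number of items.
    return [0, *accumulate(len(lst) for lst in lists)]


def locate(offsets: List[int], index: int) -> Tuple[int, int]:
    """Map a flat index to (sub-list number, index inside that sub-list)."""
    if not 0 <= index < offsets[-1]:
        raise IndexError("flat index out of range")
    # bisect_right moves past equal offsets, so an empty sub-list is never chosen.
    k = bisect_right(offsets, index) - 1
    return k, index - offsets[k]


lists = [["a", "b", "c"], [], ["d", "e"]]
offsets = offsets_of(lists)
print(offsets)              # [0, 3, 3, 5]
k, i = locate(offsets, 3)
print(k, i, lists[k][i])    # 2 0 d
```

    The empty middle list shows why bisect_right (rather than bisect_left) is the natural choice here: duplicate offsets belong to empty sub-lists, and bisect_right skips over them.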

    Implementation

    The type hints below use these imports:

    from typing import Any, Callable, Iterable, List, Tuple
    

    If you want the view to react to changes in the underlying lists, you need a ListWrapper that delegates to the underlying list and notifies the view whenever that list changes.

    Every mutating method below therefore calls self._notify() after delegating to the wrapped list.

    class ListWrapper:
        def __init__(self, lst: List[Any]):
            self._list: List[Any] = lst
            self.callbacks: List[Callable[[], None]] = []
    
        def __getitem__(self, index: int) -> Any:
            return self._list[index]
    
        def __setitem__(self, index: int, value: Any) -> None:
            self._list[index] = value
            self._notify()
    
        def __len__(self) -> int:
            return len(self._list)
    
        def append(self, item: Any) -> None:
            self._list.append(item)
            self._notify()
    
        def extend(self, iterable: Iterable[Any]) -> None:
            self._list.extend(iterable)
            self._notify()
    
        def insert(self, index: int, item: Any) -> None:
            self._list.insert(index, item)
            self._notify()
    
        def remove(self, item: Any) -> None:
            self._list.remove(item)
            self._notify()
    
        def pop(self, index: int = -1) -> Any:
            item = self._list.pop(index)
            self._notify()
            return item
    
        def clear(self) -> None:
            self._list.clear()
            self._notify()
    
        def _notify(self) -> None:
            for callback in self.callbacks:
                callback()
    
        def add_callback(self, callback: Callable[[], None]) -> None:
            self.callbacks.append(callback)
    
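    ListWrapper forwards only the methods it defines, so operations such as del wrapper[i], wrapper += [...], sort(), or reverse() are not available on it. One way to cover more of the list API without writing each method by hand is to generate the notifying methods in a loop. The sketch below assumes the same callback design; the class name NotifyingList and the _MUTATORS tuple are hypothetical.

```python
from typing import Any, Callable, List

# In-place mutating methods of the built-in list that should trigger a notification.
_MUTATORS = ("__setitem__", "__delitem__", "__iadd__", "append", "extend",
             "insert", "remove", "pop", "clear", "sort", "reverse")


class NotifyingList:
    """Hypothetical ListWrapper variant: wraps (does not copy) a list and
    calls every registered callback after each mutating method."""

    def __init__(self, lst: List[Any]) -> None:
        self._list = lst
        self.callbacks: List[Callable[[], None]] = []

    def add_callback(self, callback: Callable[[], None]) -> None:
        self.callbacks.append(callback)

    def __getitem__(self, index):
        return self._list[index]

    def __len__(self) -> int:
        return len(self._list)

    def __iter__(self):
        return iter(self._list)


def _make_mutator(name: str):
    def method(self, *args, **kwargs):
        # Delegate to the wrapped list, then notify the listeners.
        result = getattr(self._list, name)(*args, **kwargs)
        for callback in self.callbacks:
            callback()
        # `w += x` rebinds w to whatever __iadd__ returns, so return the wrapper.
        return self if name == "__iadd__" else result
    method.__name__ = name
    return method


for _name in _MUTATORS:
    setattr(NotifyingList, _name, _make_mutator(_name))


data = [1, 2, 3]
wrapper = NotifyingList(data)
seen = []
wrapper.add_callback(lambda: seen.append(len(wrapper)))
wrapper.append(4)
wrapper[0] = 10
del wrapper[1]
wrapper += [5, 6]
wrapper.sort()
print(data)  # [3, 4, 5, 6, 10]
print(seen)  # [4, 4, 3, 5, 5]
```

    Setting the dunder methods on the class (not on an instance) matters, because Python looks up special methods such as __delitem__ and __iadd__ on the type.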

    Here is a class which provides a "flat view" of multiple lists, where updates to the original lists are reflected in the flat view.

    class FlatView:
        def __init__(self, lists: List[ListWrapper]) -> None:
            self.lists = lists
            self.update_lengths()
            for lst in self.lists:
                lst.add_callback(self.update_lengths)
    
        def update_lengths(self) -> None:
            self.sub_lengths = self._compute_overall_length()
    
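        def _compute_overall_length(self) -> List[int]:
            # Cumulative lengths: list i starts at flat index lengths[i],
            # and lengths[-1] is the total number of items.
            lengths = [0]
            for lst in self.lists:
                lengths.append(lengths[-1] + len(lst))
            return lengths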
    
        def __getitem__(self, index: int) -> Any:
            if index < 0:
                index += len(self)  # count negative indices from the end, like list
            if index < 0 or index >= len(self):
                raise IndexError("list index out of range")
            list_index = self._find_list_index(index)
            sublist_index = index - self.sub_lengths[list_index]
            return self.lists[list_index][sublist_index]
    
        def _find_list_index(self, index: int) -> int:
            # Binary search to find the list that contains the index
            low, high = 0, len(self.sub_lengths) - 1
            while low < high:
                mid = (low + high) // 2
                if self.sub_lengths[mid] <= index < self.sub_lengths[mid + 1]:
                    return mid
                elif index < self.sub_lengths[mid]:
                    high = mid
                else:
                    low = mid + 1
            return low
    
        def __len__(self) -> int:
            return self.sub_lengths[-1]
    
        def __repr__(self) -> str:
            return repr([item for lst in self.lists for item in lst])
    

    Usage

    The following wrapper function creates a FlatView instance for the provided list of lists.

    We wrap each list in a ListWrapper so that we can attach a callback that refreshes the view's cumulative lengths, which are used to map a flat index to the right sub-list.

    def flat_view(lists: List[List[Any]]) -> Tuple[FlatView, List[ListWrapper]]:
        wrappers = [ListWrapper(lst) for lst in lists]
        return FlatView(wrappers), wrappers
    

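    Here is how you would use it. Note that structural changes (append, insert, pop, remove, clear, extend) must go through the wrappers so that the view can refresh its cached lengths; calling a.append("y") on the raw list would leave len(view) and the index mapping stale.

    Replacing an element in place, such as b[0] = "x", does not change any length, so the view reflects it even when it is done on the raw list.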
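    The difference between in-place replacement and length changes can be seen with a few lines that cache the cumulative lengths once, the way FlatView does between notifications. This is an illustrative sketch; the lookup helper is hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate

a = ["a", "b", "c"]
b = ["d", "e", "f"]
lists = [a, b]

# Cumulative lengths computed once and then never refreshed.
offsets = [0, *accumulate(len(lst) for lst in lists)]  # [0, 3, 6]


def lookup(index):
    """Map a flat index using the cached offsets (no bounds checks)."""
    k = bisect_right(offsets, index) - 1
    return lists[k][index - offsets[k]]


b[0] = "x"          # in-place replacement: the offsets are still valid
print(lookup(3))    # x

a.append("y")       # length change on the raw list: the offsets are now stale
print(lookup(3))    # x, although the real flat index 3 is now "y"
print(offsets[-1], sum(len(lst) for lst in lists))  # 6 7
```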

    if __name__ == "__main__":
        a = ["a", "b", "c"]
        b = ["d", "e", "f"]
        view, [a_wrapper, b_wrapper] = flat_view([a, b])
        
        print(view)                # Output: ['a', 'b', 'c', 'd', 'e', 'f']
        
        b_wrapper[0] = "x"
        print(view)                # Output: ['a', 'b', 'c', 'x', 'e', 'f']
        
        print(view[3], len(view))  # Output: x 6
        
        a_wrapper.append("y")
        print(view, len(view))     # Output: ['a', 'b', 'c', 'y', 'x', 'e', 'f'] 7
    
        print(a, len(a))           # Output: ['a', 'b', 'c', 'y'] 4
        print(b, len(b))           # Output: ['x', 'e', 'f'] 3
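    If an O(k) lookup per access is acceptable (k being the number of sub-lists), the wrappers can be dropped altogether: a view that recomputes positions on every access needs no notifications, so it works with plain lists exactly as in the question, including length changes made directly on a or b. The sketch below is that alternative trade-off, not part of the answer above; the class name LiveFlatView is hypothetical. Subclassing collections.abc.Sequence provides iteration-based helpers such as in, index(), count() and reversed() from __len__ and __getitem__.

```python
from collections.abc import Sequence
from typing import Any, Iterator, List


class LiveFlatView(Sequence):
    """Hypothetical wrapper-free, read-only flat view over plain lists.

    Nothing is cached, so any change to the sub-lists (including append()
    or pop() on the raw lists) is visible immediately. The cost is that
    len() and indexing walk the sub-lists: O(k) for k sub-lists.
    """

    def __init__(self, lists: List[List[Any]]) -> None:
        self._lists = lists  # references to the original lists, not copies

    def __len__(self) -> int:
        return sum(len(lst) for lst in self._lists)

    def __getitem__(self, index: int) -> Any:
        if index < 0:
            index += len(self)  # negative indices count from the end
        if index >= 0:
            # Skip whole sub-lists until the index falls inside one.
            for lst in self._lists:
                if index < len(lst):
                    return lst[index]
                index -= len(lst)
        raise IndexError("flat view index out of range")

    def __iter__(self) -> Iterator[Any]:
        for lst in self._lists:
            yield from lst

    def __repr__(self) -> str:
        return repr(list(self))


a = ["a", "b", "c"]
b = ["d", "e", "f"]
view = LiveFlatView([a, b])
b[0] = "x"         # plain list, no wrapper needed
a.append("y")      # length changes are picked up as well
print(view)                # ['a', 'b', 'c', 'y', 'x', 'e', 'f']
print(view[3], len(view))  # y 7
print(view.index("e"))     # 5
```

    This sketch does not support slicing (view[1:3] raises TypeError) or assignment through the view; the cached-lengths FlatView remains the better fit when there are many sub-lists and reads are frequent.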