python iterator iteration

For a custom Mapping class that returns self as iterator, list() returns empty. How do I fix it?


The following is a simplified version of what I am trying to do (the actual implementation has a number of nuances):

from __future__ import annotations

from collections.abc import MutableMapping

class SideDict(MutableMapping, dict):
    """
    The purpose of this special dict is to side-attach another dict. A key
    and its value from main dict are preferred over same key in the
    side-dict. If only a key is not present in main dict, then it is used
    from the side-dict.        
    """

    # The starting SideDict instance will have side_dict=None, a subsequent
    # SideDict instance can use the first instance as its side_dict.
    def __init__(self, data, side_dict: SideDict | None):
        self._store = dict(data)
        self._side_dict = side_dict

        self._iter_keys_seen = []
        self._iter_in_side_dict = False
        self._iter = None
        # Also other stuff

    # Also implements __bool__, __contains__, __delitem__, __eq__, __getitem__,
    # __missing__, __or__, __setitem__ and others.

    def __iter__(self):
        self._iter_keys_seen = []
        self._iter_in_side_dict = False
        self._iter = None
        return self

    def __next__(self):
        while True:
            # Start with an iterator that is on self._store
            if self._iter is None:
                self._iter = self._store.__iter__()

            try:
                next_ = self._iter.__next__()
                if next_ in self._iter_keys_seen:
                    continue
                # Some other stuff I do with next_
                self._iter_keys_seen.append(next_)
                return next_
            except StopIteration as e:
                if self._side_dict is None or self._iter_in_side_dict:
                    raise e
                else:
                    # Switching to side-dict iterator
                    self._iter_in_side_dict = True
                    self._iter = self._side_dict.__iter__()

    def __len__(self):
        return len([k for k in self])  # It's not the most efficient, but
                                       # I don't know any other way.

sd_0 = SideDict(data={"a": "A"}, side_dict=None)
sd_1 = SideDict(data={"b": "B"}, side_dict=sd_0)
sd_2 = SideDict(data={"c": "C"}, side_dict=sd_1)

print(len(sd_0), len(sd_1), len(sd_2))  # all work fine
print(list(sd_0))  # ! Here is the problem, shows empty list `[]` !

On putting some print()s, here is what I observed being called:

  1. list() triggers obj.__iter__() first.
  2. Followed by obj.__len__(). I vaguely understand that this is done so as to allocate optimal length of list.
  3. Because obj.__len__() has list-comprehension ([k for k in self]), it again triggers obj.__iter__().
  4. Followed by obj.__next__() multiple times as it iterates through obj._store and obj._side_dict.
  5. When obj.__next__() hits the final un-silenced StopIteration, list-comprehension in obj.__len__() ends.
  6. Here the problem starts. list() seems to be calling obj.__next__() again immediately after ending obj.__len__(), and it hits StopIteration again. There is no obj.__iter__(). And so the final result is an empty list!

What I think might be happening is that list() starts an iterator on its argument, but before doing anything else, it wants to find out the length. My __len__() iterates too, so it seems both are using the same iterator. That iterator is then consumed inside obj.__len__(), and nothing is left for the outer list() to consume. Please correct me if I am wrong.

So how can I change my obj.__len__() to use a non-clashing iterator?


Solution

  • The problem is that your object is its own iterator. Most objects should not be their own iterator - it only makes sense to do that if the object's only job is to be an iterator, or if there's some other inherent reason you shouldn't be able to perform two independent loops over the same object.
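
    A minimal sketch of the failure mode, using a hypothetical SelfIterating class (not the asker's code) that returns itself from __iter__. Because there is only one cursor, anything that iterates inside __len__ also exhausts the iterator an outer caller is holding, which matches the call sequence traced in the question:

```python
class SelfIterating:
    """Hypothetical container that is its own iterator."""

    def __init__(self, items):
        self._items = list(items)
        self._pos = 0

    def __iter__(self):
        self._pos = 0  # resets the single shared cursor
        return self

    def __next__(self):
        if self._pos >= len(self._items):
            raise StopIteration
        self._pos += 1
        return self._items[self._pos - 1]

    def __len__(self):
        # Iterating here advances the same cursor an outer loop relies on.
        return sum(1 for _ in self)


obj = SelfIterating("abc")
outer = iter(obj)                  # what list() obtains first
same_object = outer is iter(obj)   # every iter() call hands back obj itself
length = len(obj)                  # runs the shared cursor to the end
first = next(outer, "exhausted")   # the outer iterator has nothing left

print(same_object, length, first)  # True 3 exhausted
```

    With a fresh iterator per __iter__ call, the len() call in the middle would not affect `outer`.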

    Most iterable objects should return a new iterator object from __iter__, and not implement __next__. The simplest way to do this is usually by either writing __iter__ as a generator function, or returning an iterator over some other object that happens to have the right elements. For example, using the set-like union functionality of dict key views:

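    def __iter__(self):
        # The first SideDict has side_dict=None, so it contributes no extra keys.
        side_keys = () if self._side_dict is None else self._side_dict.keys()
        return iter(self._store.keys() | side_keys)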
    

    Or using a generator:

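    def __iter__(self):
        yield from self._store

        # The first SideDict has side_dict=None, so there is nothing more to yield.
        if self._side_dict is None:
            return
        for key in self._side_dict:
            if key not in self._store:
                yield key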
    

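    In this case, the generator has the advantage of not building an intermediate set of all the keys. It also yields keys in a predictable order (the main dict's keys first, then whatever remains in the side-dict), whereas a set has no defined order.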
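
    For reference, here is a trimmed-down, runnable sketch of the class with a generator-based __iter__. It is an assumption-laden simplification rather than the asker's real implementation: it subclasses only MutableMapping (not dict), defines just the abstract methods, and gives side_dict a default of None. The override of "a" in sd_2 is added here only to show the precedence rule.

```python
from collections.abc import MutableMapping


class SideDict(MutableMapping):
    """Simplified sketch: only the methods MutableMapping requires."""

    def __init__(self, data, side_dict: "SideDict | None" = None):
        self._store = dict(data)
        self._side_dict = side_dict

    def __getitem__(self, key):
        if key in self._store:
            return self._store[key]
        if self._side_dict is not None:
            return self._side_dict[key]
        raise KeyError(key)

    def __setitem__(self, key, value):
        self._store[key] = value

    def __delitem__(self, key):
        del self._store[key]

    def __iter__(self):
        # Each call creates a fresh generator, so loops do not share state.
        yield from self._store
        if self._side_dict is None:
            return
        for key in self._side_dict:
            if key not in self._store:
                yield key

    def __len__(self):
        # Safe now: this loop runs on its own independent iterator.
        return sum(1 for _ in self)


sd_0 = SideDict({"a": "A"})
sd_1 = SideDict({"b": "B"}, side_dict=sd_0)
sd_2 = SideDict({"c": "C", "a": "override"}, side_dict=sd_1)

print(len(sd_0), len(sd_1), len(sd_2))  # 1 2 3
print(list(sd_2))                       # ['c', 'a', 'b']
print(sd_2["a"], sd_2["b"])             # override B
```

    Since __iter__ no longer touches instance state, the _iter_keys_seen, _iter_in_side_dict and _iter attributes from the question are not needed in this version.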


    Also, unless you're writing this thing as a learning exercise, you should probably just use collections.ChainMap. It handles all of this already.
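
    A short sketch of the same layering with collections.ChainMap from the standard library. The dict names and the "override" value are made up for illustration; lookups search the maps front to back, and new_child() places a new map at the front:

```python
from collections import ChainMap

base = {"a": "A"}
middle = {"b": "B"}
top = {"c": "C", "a": "override"}

# Earlier maps win lookups; new_child() puts its argument at the front.
chain = ChainMap(base).new_child(middle).new_child(top)

print(chain["a"])     # override
print(len(chain))     # 3
print(sorted(chain))  # ['a', 'b', 'c']

# Writes and deletions only affect the first map in the chain.
chain["d"] = "D"
print("d" in top, "d" in base)  # True False
```

    Whether ChainMap is a drop-in fit depends on the nuances the question leaves out, so treat this as a starting point.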