The following is a simplified version of what I am trying to do (the actual implementation has a number of nuances):
```python
from __future__ import annotations
from collections.abc import MutableMapping


class SideDict(MutableMapping, dict):
    """
    The purpose of this special dict is to side-attach another dict. A key
    and its value from the main dict are preferred over the same key in the
    side-dict. Only if a key is not present in the main dict is it looked
    up in the side-dict.
    """

    # The starting SideDict instance will have side_dict=None; a subsequent
    # SideDict instance can use the first instance as its side_dict.
    def __init__(self, data, side_dict: SideDict | None):
        self._store = dict(data)
        self._side_dict = side_dict
        self._iter_keys_seen = []
        self._iter_in_side_dict = False
        self._iter = None
        # Also other stuff

    # Also implements __bool__, __contains__, __delitem__, __eq__, __getitem__,
    # __missing__, __or__, __setitem__ and others.

    def __iter__(self):
        self._iter_keys_seen = []
        self._iter_in_side_dict = False
        self._iter = None
        return self

    def __next__(self):
        while True:
            # Start with an iterator that is on self._store
            if self._iter is None:
                self._iter = self._store.__iter__()
            try:
                next_ = self._iter.__next__()
                if next_ in self._iter_keys_seen:
                    continue
                # Some other stuff I do with next_
                self._iter_keys_seen.append(next_)
                return next_
            except StopIteration as e:
                if self._side_dict is None or self._iter_in_side_dict:
                    raise e
                else:
                    # Switching to side-dict iterator
                    self._iter_in_side_dict = True
                    self._iter = self._side_dict.__iter__()

    def __len__(self):
        # It's not the most efficient, but I don't know any other way.
        return len([k for k in self])


sd_0 = SideDict(data={"a": "A"}, side_dict=None)
sd_1 = SideDict(data={"b": "B"}, side_dict=sd_0)
sd_2 = SideDict(data={"c": "C"}, side_dict=sd_1)
print(len(sd_0), len(sd_1), len(sd_2))  # all work fine
print(list(sd_0))  # ! Here is the problem, shows empty list `[]` !
```
After adding some `print()`s, here is what I observed being called:

1. `list()` triggers `obj.__iter__()` first.
2. `obj.__len__()` is called next. I vaguely understand that this is done so that `list()` can allocate a list of the right length.
3. `obj.__len__()` contains a list comprehension (`[k for k in self]`), so it triggers `obj.__iter__()` again.
4. `obj.__next__()` is called multiple times as it iterates through `obj._store` and `obj._side_dict`.
5. `obj.__next__()` hits the final, un-silenced `StopIteration`, and the list comprehension in `obj.__len__()` ends.
6. `list()` seems to call `obj.__next__()` again immediately after `obj.__len__()` returns, and it hits `StopIteration` again. There is no new call to `obj.__iter__()`. And so the final result is an empty list!

What I think might be happening is that `list()` starts an iterator on its argument, but before doing anything else, it wants to find out the length. My `__len__()` uses an iterator itself, so it seems that both are using the same iterator. This iterator is then consumed in `obj.__len__()`, and nothing is left for the outer `list()` to consume. Please correct me if I am wrong.
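The hypothesis can be checked without involving `list()` at all. The sketch below uses a hypothetical minimal class, `SelfIterating`, that, like `SideDict`, returns `self` from `__iter__` and counts its items in `__len__` by looping over itself. Starting an "outer" iteration and then calling `len()` shows that both share one cursor.

```python
class SelfIterating:
    """Hypothetical minimal class that, like SideDict, is its own iterator."""

    def __init__(self, data):
        self._data = list(data)
        self._pos = 0

    def __iter__(self):
        self._pos = 0  # reset the one shared cursor...
        return self    # ...and hand back the same object every time

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        item = self._data[self._pos]
        self._pos += 1
        return item

    def __len__(self):
        return sum(1 for _ in self)  # re-iterates the *same* object


obj = SelfIterating("ab")
outer = iter(obj)            # start an "outer" iteration, as list() does
print(outer is obj)          # True: the iterator is the object itself
print(len(obj))              # 2, but computing it drains the shared cursor
print(next(outer, "empty"))  # empty: nothing is left for the outer iteration
```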
So how can I change my `obj.__len__()` to use a non-clashing iterator?
The problem is that your object is its own iterator. Most objects should not be their own iterator - it only makes sense to do that if the object's only job is to be an iterator, or if there's some other inherent reason you shouldn't be able to perform two independent loops over the same object.
Most iterable objects should return a new iterator object from `__iter__`, and not implement `__next__`. The simplest way to do this is usually either to write `__iter__` as a generator function, or to return an iterator over some other object that happens to have the right elements. For example, using the set-like union functionality of dict key views:
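```python
def __iter__(self):
    if self._side_dict is None:  # the first SideDict in a chain has no side-dict
        return iter(self._store)
    return iter(self._store.keys() | self._side_dict.keys())
```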
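The union of two key views is an ordinary, fully built `set`, so a key present in both dicts appears only once. A quick check with two hypothetical plain dicts, `main` and `side`:

```python
main = {"a": 1, "b": 2}
side = {"b": 20, "c": 30}

union = main.keys() | side.keys()  # set-like union of two dict_keys views
print(type(union).__name__)        # set: a new container holding every key
print(sorted(union))               # ['a', 'b', 'c']: the shared key "b" appears once
```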
Or using a generator:
```python
def __iter__(self):
    yield from self._store
    if self._side_dict is not None:  # the first SideDict in a chain has no side-dict
        for key in self._side_dict:
            if key not in self._store:
                yield key
```
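In this case, the generator has the advantage of not building the `self._store.keys() | self._side_dict.keys()` set. It also yields keys in a predictable order (main-dict keys first), whereas iteration order over a set is arbitrary.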
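To confirm that fresh iterators also resolve the `len()`/`list()` clash from the question, here is a minimal, read-only sketch of the generator approach. It is a simplification under stated assumptions: it subclasses only `collections.abc.Mapping` (not `dict`), omits `__setitem__`, `__delitem__` and the question's "other stuff", and computes `__len__` by counting keys from a fresh generator.

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping
from typing import Any


class SideDict(Mapping):
    """Read-only sketch: main-dict keys first, then the side-dict's extra keys."""

    def __init__(self, data, side_dict: SideDict | None = None):
        self._store = dict(data)
        self._side_dict = side_dict

    def __getitem__(self, key: Any) -> Any:
        if key in self._store:
            return self._store[key]
        if self._side_dict is not None:
            return self._side_dict[key]  # raises KeyError if absent there too
        raise KeyError(key)

    def __iter__(self) -> Iterator[Any]:
        # Each call creates a brand-new generator, so loops never share state.
        yield from self._store
        if self._side_dict is not None:
            for key in self._side_dict:
                if key not in self._store:
                    yield key

    def __len__(self) -> int:
        # Safe now: this count uses its own generator, independent of any
        # iteration that list() (or the caller) has already started.
        return sum(1 for _ in self)


sd_0 = SideDict({"a": "A"})
sd_1 = SideDict({"b": "B"}, side_dict=sd_0)
sd_2 = SideDict({"c": "C", "a": "override"}, side_dict=sd_1)

print(list(sd_0))                               # ['a']
print(len(sd_2), list(sd_2))                    # 3 ['c', 'a', 'b']
print(sd_2["a"], sd_2["b"])                     # override B
print([(x, y) for x in sd_1 for y in sd_1])     # nested loops work independently
```

Because every `for` loop (including the hidden one in `__len__`) gets its own generator, `list()` can call `__len__` at any point without disturbing its own iteration.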
Also, unless you're writing this thing as a learning exercise, you should probably just use `collections.ChainMap`. It handles all of this already.
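A rough `ChainMap` counterpart of the question's `sd_0`, `sd_1` and `sd_2`, as a sketch: `new_child(m)` returns a new `ChainMap` with `m` placed in front of the existing maps, so lookups check `m` first, which matches the "main dict wins" rule.

```python
from collections import ChainMap

# Hypothetical equivalents of sd_0, sd_1 and sd_2 from the question.
sd_0 = ChainMap({"a": "A"})
sd_1 = sd_0.new_child({"b": "B"})  # {"b": "B"} is searched before sd_0's maps
sd_2 = sd_1.new_child({"c": "C"})

print(len(sd_0), len(sd_1), len(sd_2))  # 1 2 3
print(sorted(sd_0), sorted(sd_2))       # ['a'] ['a', 'b', 'c']
print(sd_2["a"])                        # A, found in the last mapping
```

One difference to check against the real requirements: `ChainMap` applies writes, updates and deletions only to its first mapping.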