
TypeError: 'type' object is not iterable when iterating over collections.deque that contains collections.namedtuple


I made a simple replay buffer, but when I sample from it I get the error TypeError: 'type' object is not iterable:

import collections
import numpy as np

Experience = collections.namedtuple("Experience", field_names=["state", "action", "reward", "done", "next_state"])

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def add_exp(self, exp: Experience):
        self.buffer.append(exp)

    def sample(self, batch_size):
        idxs = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in idxs])

        return np.array(states), np.array(actions), \
               np.array(rewards, dtype=np.float32), \
               np.array(dones, dtype=np.uint8), \
               np.array(next_states)

When I print the type of self.buffer[0], it gives 'type', but shouldn't it be ReplayBuffer.Experience?
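The sketch below shows what the zip(*...) line in sample does. It reuses the question's Experience definition, and the transition values are made up for illustration. With real Experience instances, zip transposes the batch into one tuple per field. With the class object itself in the batch, zip has to iterate over a class, and that iteration fails with exactly the error above.

```python
import collections

# Same namedtuple definition as in the question.
Experience = collections.namedtuple(
    "Experience", field_names=["state", "action", "reward", "done", "next_state"]
)

# Two made-up transitions (hypothetical values, for illustration only).
batch = [
    Experience(state=0, action=1, reward=0.5, done=False, next_state=1),
    Experience(state=1, action=0, reward=1.0, done=True, next_state=2),
]

# zip(*batch) transposes the batch: one tuple per field.
states, actions, rewards, dones, next_states = zip(*batch)
print(states)   # (0, 1)
print(rewards)  # (0.5, 1.0)

# If the buffer holds the class itself instead of instances,
# zip has to iterate over a class object, which is not iterable.
bad_batch = [Experience, Experience]
print(type(bad_batch[0]))  # <class 'type'>
try:
    zip(*bad_batch)
except TypeError as err:
    print(err)  # 'type' object is not iterable
```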


Solution

  • You're adding the Experience type itself to your buffer, not an instance of that type. What you're doing is essentially the same as this:

    class Experience:
        pass
    
    buffer = []
    
    buffer.append(Experience)
    

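    Each element of the buffer is therefore the class itself rather than a tuple of data. When sample runs zip(*...) over those elements, it tries to iterate over a class object, which raises TypeError: 'type' object is not iterable. You need to create an instance of Experience first, then add that instance to the buffer, something like this: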

    exp = Experience(the_state, the_action, the_reward, the_done, the_next_state)
    buff.add_exp(exp)
    

    Here the_state, the_action, and the other the_ variables hold the data for a single transition, and buff is your ReplayBuffer instance.


    Also note that a more modern way to write Experience is as a class that subclasses typing.NamedTuple:

    from typing import NamedTuple

    class Experience(NamedTuple):
        state: state_type
        action: action_type
        reward: reward_type
        done: done_type
        next_state: state_type
    

    Here each _type name is a placeholder for the type of that field. Annotating the fields lets type checkers such as mypy check the type of each field for you.
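Below is a minimal end-to-end sketch that puts the fix together. The concrete field types (np.ndarray states, an int action) are assumptions standing in for the _type placeholders. The isinstance guard in add_exp is a hypothetical addition and is not part of the original code. The buffer stores Experience instances, so sample can transpose a batch with zip(*...).

```python
import collections
from typing import NamedTuple

import numpy as np


# Concrete field types are assumptions for this sketch
# (e.g. a vector observation and a discrete action).
class Experience(NamedTuple):
    state: np.ndarray
    action: int
    reward: float
    done: bool
    next_state: np.ndarray


class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buffer: collections.deque = collections.deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def add_exp(self, exp: Experience) -> None:
        # Hypothetical guard: reject the class itself (or anything else
        # that is not an Experience instance) at insertion time.
        if not isinstance(exp, Experience):
            raise TypeError(f"expected an Experience instance, got {exp!r}")
        self.buffer.append(exp)

    def sample(self, batch_size: int):
        # Pick distinct indices, then transpose the sampled tuples per field.
        idxs = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in idxs])
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8),
            np.array(next_states),
        )


# Usage with made-up transitions (hypothetical 4-dimensional states).
buff = ReplayBuffer(capacity=100)
for step in range(10):
    state = np.full(4, step, dtype=np.float32)
    exp = Experience(state, step % 2, 1.0, step == 9, state + 1)
    buff.add_exp(exp)

states, actions, rewards, dones, next_states = buff.sample(batch_size=5)
print(states.shape, rewards.dtype)  # (5, 4) float32
```

Calling buff.add_exp(Experience) now fails immediately with a clear message instead of failing later inside sample.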