I am trying to use a prioritized replay buffer for my DQN agent. The problem I encounter is the following: my environment has a (40, 40, 1) state representation, and when I try to add a transition to the buffer, I get:
RuntimeError: expand(torch.DoubleTensor{[40, 40, 1]}, size=[3]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (3)
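For reference, the same error can be reproduced in isolation; this is a minimal sketch with sizes matching my setup (variable names are mine):

import torch

# PyTorch tries to expand the right-hand side to the row shape (3,);
# a 3-dimensional (40, 40, 1) tensor cannot be expanded to size [3].
buffer = torch.empty(10, 3, dtype=torch.float)      # rows of shape (3,)
state = torch.rand(40, 40, 1, dtype=torch.double)   # one observation
buffer[0] = state                                   # raises the RuntimeError above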
The Prioritized Replay buffer code:
import torch

class PrioritizedReplayBuffer:
    def __init__(self, state_size=3, action_size=1, buffer_size=10000,
                 eps=1e-2, alpha=0.1, beta=0.1):
        self.tree = SumTree(size=buffer_size)  # SumTree is defined elsewhere in the project

        # PER params
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = eps

        # transition: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        state, action, reward, next_state, done = transition

        # store transition index with maximum priority in sum tree
        self.tree.add(self.max_priority, self.count)

        # store transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)
Any help would be appreciated. Thanks.
The problem is solved by allocating the state buffers with the full (40, 40, 1) observation shape instead of the default state_size=3:
        # transition: state, action, reward, next_state, done
        self.state = torch.empty((buffer_size, 40, 40, 1), dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty((buffer_size, 40, 40, 1), dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)
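A more flexible variant (my own suggestion, not part of the original fix) is to pass the observation shape to the buffer instead of hardcoding 40 x 40 x 1. A minimal sketch, assuming a state_shape parameter of my own naming:

import torch

# Hedged sketch: allocate per-transition buffers from an arbitrary state
# shape rather than a hardcoded (40, 40, 1).
def make_transition_buffers(buffer_size, state_shape, action_size=1):
    state = torch.empty((buffer_size, *state_shape), dtype=torch.float)
    action = torch.empty(buffer_size, action_size, dtype=torch.float)
    reward = torch.empty(buffer_size, dtype=torch.float)
    next_state = torch.empty((buffer_size, *state_shape), dtype=torch.float)
    done = torch.empty(buffer_size, dtype=torch.int)
    return state, action, reward, next_state, done

# usage: the buffers now follow the environment's observation space
state, action, reward, next_state, done = make_transition_buffers(10000, (40, 40, 1))

This way the same buffer class works for any environment, since the preallocated rows always match the shape of the states being stored.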