MellowMax is a softmax operator that can be used instead of the max operator in the context of deep Q-learning. Using MellowMax has been shown to remove the need for a target network. Link to paper: https://arxiv.org/abs/1612.05628
To estimate a target Q value, you apply mellowmax to the Q values of the next state. The mellowmax function looks like this:

mm_w(x) = (1/w) * log( (1/n) * sum_i exp(w * x_i) )

where x is the tensor of Q values, n is the number of actions, and w is a temperature parameter.
My implementation is:
import tensorflow as tf

# DEEP_MELLOW_TEMPERATURE_VALUE (w) and NUM_ACTIONS (n) are defined elsewhere.
def mellow_max(q_values):
    q_values = tf.cast(q_values, tf.float64)
    powers = tf.multiply(q_values, DEEP_MELLOW_TEMPERATURE_VALUE)     # w * x
    summation_values = tf.math.exp(powers)                            # exp(w * x)
    summation = tf.math.reduce_sum(summation_values, axis=1)
    val_for_log = tf.multiply(summation, (1 / NUM_ACTIONS))           # (1/n) * sum
    numerator = tf.math.log(val_for_log)
    mellow_val = tf.math.divide(numerator, DEEP_MELLOW_TEMPERATURE_VALUE).numpy()
    return mellow_val
My issue is that the tf.math.exp call in this function returns values of +inf when using a temperature value w of 1,000. I'm using w = 1,000 because that's what the paper above found to be optimal on the Atari Breakout testbed.
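For reference, here is a minimal sketch of the overflow with made-up Q values (float64 can only represent exp(x) for x up to roughly 709, so w * q blows past that once w = 1,000 and q is of order 1):

import tensorflow as tf

q = tf.constant([[0.8, 1.2, 1.5]], dtype=tf.float64)       # hypothetical Q values
print(tf.math.exp(q * 1000.0))                              # [[inf inf inf]]
print(tf.math.exp(tf.constant(709.0, dtype=tf.float64)))    # ~8.2e+307, still finite
print(tf.math.exp(tf.constant(710.0, dtype=tf.float64)))    # inf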
Any suggestions would be appreciated on how I can prevent that exp call from blowing up the calculation. Maybe evaluating the limit of the function as w goes to 1,000 would work. Any suggestions on how I could do that in TensorFlow?
You cannot compute mellowmax like this, because exp will overflow/underflow quickly when w * x_i is large. You have to rewrite it in a smarter way, for example:

mm_w(x) = m + ( log( sum_i exp(w * (x_i - m)) ) - log(n) ) / w,   where m = max_i(x_i)

Here the arguments inside the logsumexp part are all non-positive, which solves the overflow issue.
Notice there is a logsumexp term. We know LSE approaches log(K) when w is very large, where K is the number of maximal values present in x. You can use this to manually verify your result a bit.
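For example, a quick check of that limit could look like this (torch.logsumexp applies the same max-shift internally, so it is safe for large w):

import math
import torch

x = torch.tensor([1.0, 2.0, 3.0])         # single maximum -> K = 1, log(K) = 0
x_tied = torch.tensor([1.0, 3.0, 3.0])    # two tied maxima -> K = 2, log(K) ~ 0.693
for w in (1.0, 10.0, 1000.0):
    lse = torch.logsumexp(w * (x - x.max()), dim=0)
    lse_tied = torch.logsumexp(w * (x_tied - x_tied.max()), dim=0)
    print(w, lse.item(), lse_tied.item(), math.log(2))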
If you wish to use a very small w << 1, you would have to take care of underflow with a similar technique: first compute the mean value, then do the logsumexp around the mean instead of the max.
I was wrong, there is no underflow risk here.
Here is my example:
import torch

def mellowmax(a: torch.Tensor, w: float):
    m = torch.max(a)
    N = torch.Tensor([len(a),])
    # since a - m is non-positive, we can compute the lse directly without overflow
    lse = torch.exp((a - m) * w).sum().log_()
    return m + (lse - N.log_()) / w
N = 10
a = torch.randn((N,), dtype=torch.float) * N
for n in range(-4, 5):
    w = 10**n
    mwm = mellowmax(a, w)
    print(mwm, a.max(), a.mean())
The result is:
tensor([2.1293]) tensor(17.7385) tensor(2.1235)
tensor([2.1791]) tensor(17.7385) tensor(2.1235)
tensor([2.6696]) tensor(17.7385) tensor(2.1235)
tensor([6.6293]) tensor(17.7385) tensor(2.1235)
tensor([15.4587]) tensor(17.7385) tensor(2.1235)
tensor([17.5083]) tensor(17.7385) tensor(2.1235)
tensor([17.7155]) tensor(17.7385) tensor(2.1235)
tensor([17.7362]) tensor(17.7385) tensor(2.1235)
tensor([17.7383]) tensor(17.7385) tensor(2.1235)
We can see that the mellowmax is very close to the mean at first, and then becomes very close to the max as w increases.
Please note that meaningful values of w are usually within about 10, so your w = 1,000 could be the result of some other issue. Nevertheless, depending on your x, the naive way of computing mellowmax can still overflow quite often.
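Translated back to the TensorFlow setting of the question, a numerically stable version could look like the sketch below (the constant names DEEP_MELLOW_TEMPERATURE_VALUE and NUM_ACTIONS are assumed from the question, with placeholder values):

import tensorflow as tf

# Assumed constants, mirroring the names used in the question.
DEEP_MELLOW_TEMPERATURE_VALUE = 1000.0
NUM_ACTIONS = 4

def mellow_max_stable(q_values):
    # mm_w(x) = max(x) + (logsumexp(w * (x - max(x))) - log(n)) / w
    q_values = tf.cast(q_values, tf.float64)
    w = tf.constant(DEEP_MELLOW_TEMPERATURE_VALUE, tf.float64)
    m = tf.reduce_max(q_values, axis=1, keepdims=True)       # per-row maximum
    # the shifted arguments w * (q - m) are all <= 0, so exp cannot overflow
    lse = tf.reduce_logsumexp(w * (q_values - m), axis=1)
    n = tf.cast(NUM_ACTIONS, tf.float64)
    return tf.squeeze(m, axis=1) + (lse - tf.math.log(n)) / w

# With w = 1000 the result sits just below the maximum Q value, as expected.
print(mellow_max_stable(tf.constant([[1.0, 2.0, 3.0, 4.0]])))   # ~3.99861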