I am trying to implement a softmax function that takes in signed int8 input and returns a signed int8 output array.
My current implementation is this:
import numpy as np

def softmax_int8(inputs):
    inputs = np.array(inputs, dtype=np.int8)
    x = inputs.astype(np.int32)
    x_max = np.max(x)
    x_shifted = x - x_max
    scale_factor = 2 ** 14
    exp_limit = 16
    exp_x = np.clip(x_shifted + exp_limit, 0, None)
    exp_x = (1 << exp_x)
    sum_exp_x = np.sum(exp_x)
    if sum_exp_x == 0:
        sum_exp_x = 1
    softmax_probs = (exp_x * scale_factor) // sum_exp_x
    max_prob = np.max(softmax_probs)
    min_prob = np.min(softmax_probs)
    range_prob = max_prob - min_prob if max_prob != min_prob else 1
    scaled_probs = ((softmax_probs - min_prob) * 255) // range_prob - 128
    outputs = scaled_probs.astype(np.int8)
    return outputs
I test it using this input:
Input = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]
but I get this output:
array([-128, -128, -128, -128, -128, -128, -128, 127, -128, -121], dtype=int8)
My expected output is:
array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=int8)
What am I doing wrong here and how can I fix it?
I think softmax has different mathematical definitions in different contexts, for example:
exp(z) / sum(exp(z))
(1<<(z-z_max + 16)) / sum((1 << (z-z_max + 16)))
or something similar (1<< is equivalent to 2**, obviously). The major difference is the base of the exponential. If the base is too high for the range of the inputs, the smaller terms are very likely to underflow and you get a lot of -128. Besides that, there is also a bias that maps the result to the [-128, 127] range, which is trivial and less important.
It is highly likely that the library you took your test cases from uses a definition different from both of the above.
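To make the underflow concrete, here is a minimal sketch that reproduces only the shift step of the question's code on the question's test input. The names `shift` and `collapsed` are mine, introduced for illustration; `exp_limit` is the constant from the question.

```python
import numpy as np

# Test input from the question.
x = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int32)

exp_limit = 16  # same constant as in the question's code

# Exponent passed to the shift, clipped at 0 exactly as in the question.
shift = np.clip(x - x.max() + exp_limit, 0, None)
exp_x = np.left_shift(1, shift)  # 1 << shift, i.e. 2 ** shift

# Any input at least exp_limit below the maximum ends up with a shift of 0,
# so all of those inputs get the same smallest term.
collapsed = int(np.sum(shift == 0))
print(shift)
print(collapsed, "of", x.size, "inputs share the smallest term")
```

With this input, only 120 and 115 keep distinct terms; every other value lands on the same minimum, which is why the final rescaling sends them all to -128.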
I did some testing with your test case and a floating-point definition of softmax, comparing the curves with matplotlib, and the following expression gives a good fit:
softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
This suggests that you probably need to apply something like >>7 (a division by 128) to the input bytes before doing the 1<< base-2 exponential. To get completely identical results you would have to dig into that library's code, which I didn't have time to do.
Below is the validation code:
import numpy as np
import matplotlib.pyplot as plt
inarr = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int8).astype(np.double)
expected_arr = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8).astype(np.double)
print(expected_arr)
softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
print(softmax_naive - expected_arr)
plt.plot(inarr)
plt.plot(expected_arr)
plt.plot(softmax_naive)
plt.show()
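If it helps, the fitted expression can be wrapped into a function that actually returns int8. This is only a floating-point sketch of the empirical fit, not the library's algorithm: the function name `softmax_int8_fit` and the parameters `temperature`, `out_scale` and `offset` are hypothetical, and their defaults (128, 256, -100) are simply the constants from the fit, so expect differences of a few counts from the expected array.

```python
import numpy as np

def softmax_int8_fit(inputs, temperature=128.0, out_scale=256.0, offset=-100.0):
    """Floating-point reference for the empirical fit; all parameters are guesses."""
    x = np.asarray(inputs, dtype=np.int8).astype(np.float64) / temperature
    x -= x.max()  # subtracting the max does not change softmax, it only avoids overflow
    e = np.exp(x)
    probs = e / e.sum()
    scaled = np.rint(probs * out_scale + offset)
    # Saturate to the int8 range instead of letting the cast wrap around.
    return np.clip(scaled, -128, 127).astype(np.int8)

inputs = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]
expected = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8)

result = softmax_int8_fit(inputs)
print(result)
# Per-element difference from the expected output (computed in int32 to avoid wraparound).
print(result.astype(np.int32) - expected.astype(np.int32))
```

Rounding and clipping before the cast matters here: `astype(np.int8)` on an out-of-range value wraps instead of saturating, which would hide a bad scale or offset.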