python, machine-learning, deep-learning

How to implement softmax in Python when the inputs are signed 8-bit integers


I am trying to implement a softmax function that takes in signed int8 input and returns a signed int8 output array.

The current implementation I have is this:

import numpy as np

def softmax_int8(inputs):
    inputs = np.array(inputs, dtype=np.int8)

    # Widen to int32 and subtract the max so the shifted values are <= 0.
    x = inputs.astype(np.int32)
    x_max = np.max(x)
    x_shifted = x - x_max

    # Base-2 "exponential": clamp the shifted values to [0, exp_limit]
    # and use them as shift amounts.
    scale_factor = 2 ** 14
    exp_limit = 16
    exp_x = np.clip(x_shifted + exp_limit, 0, None)
    exp_x = (1 << exp_x)
    sum_exp_x = np.sum(exp_x)

    if sum_exp_x == 0:
        sum_exp_x = 1

    # Fixed-point probabilities, then an affine rescale into the int8 range.
    softmax_probs = (exp_x * scale_factor) // sum_exp_x
    max_prob = np.max(softmax_probs)
    min_prob = np.min(softmax_probs)
    range_prob = max_prob - min_prob if max_prob != min_prob else 1

    scaled_probs = ((softmax_probs - min_prob) * 255) // range_prob - 128
    outputs = scaled_probs.astype(np.int8)

    return outputs

I test it using this input: Input = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]

but I get this output: array([-128, -128, -128, -128, -128, -128, -128, 127, -128, -121], dtype=int8).

My expected output is array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=int8).

What am I doing wrong here and how can I fix it?


Solution

  • I think there are different mathematical definitions of softmax in different contexts.

    The major difference is the base of the exponential. If the base is too high, the result is very likely to underflow and you end up with a lot of -128 values; in your code, every element more than exp_limit = 16 below the maximum collapses into the same smallest bucket, which is why almost the whole output is -128 (a short trace at the end of this answer makes this concrete). Besides that, there is also a bias that maps the result to the [-128, 127] range, which is trivial and less important.

    It's highly likely that the library you took your test cases from uses a different definition than either of the above.

    I did some testing with your test case against a floating-point definition of softmax, comparing with matplotlib, and the following expression gives a good fit:

    softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
    

    You can imagine that you probably need to apply a >> 7 shift to the input bytes before doing the 1 << x base-2 exponential; a minimal sketch along these lines follows at the end of this answer. To get a completely identical result you would have to dig into that library's code, which I didn't have time to do.

    Below is the validation code:

    import numpy as np
    import matplotlib.pyplot as plt

    # Cast the int8 test vectors to double for the floating-point reference.
    inarr = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int8).astype(np.double)
    expected_arr = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8).astype(np.double)
    print(expected_arr)

    # Candidate definition fitted to the expected output; print the residual.
    softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
    print(softmax_naive - expected_arr)

    # Overlay the input, the expected output, and the fitted approximation.
    plt.plot(inarr)
    plt.plot(expected_arr)
    plt.plot(softmax_naive)
    plt.show()

    [plot: validation of softmax against the expected output]
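
    To make the underflow concrete, here is a short trace of the intermediate values your code produces for this input (the variable names follow your function):

    import numpy as np

    x = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int32)
    x_shifted = x - np.max(x)                      # 0 for the max (120), down to -199
    exp_x = 1 << np.clip(x_shifted + 16, 0, None)
    print(exp_x)                                   # 1 everywhere except 65536 (for 120) and 2048 (for 115)

    # Everything more than 16 below the maximum collapses to 1, so after the
    # final rescaling those entries all map to -128, the runner-up (115) maps
    # to -121, and only the maximum reaches 127 -- exactly the output you saw.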
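
    For illustration, here is a minimal int8-in / int8-out sketch of that idea. It simply wraps the fitted floating-point expression above; the function name and the in_shift / out_scale / out_zero parameters are my own, and the constants (128, 256, -100) come from the curve fit rather than from any particular library, so treat it as an approximation, not a drop-in replacement:

    import numpy as np

    def softmax_int8_scaled(inputs, in_shift=7, out_scale=256, out_zero=-100):
        # Treat the int8 codes as fixed-point numbers with `in_shift`
        # fractional bits, i.e. real value = x / 2**in_shift (here x / 128).
        x = np.asarray(inputs, dtype=np.int8).astype(np.float64) / (1 << in_shift)

        # Numerically stable floating-point softmax on the rescaled values.
        e = np.exp(x - np.max(x))
        probs = e / np.sum(e)

        # Affine mapping of the probabilities back into the int8 range,
        # mirroring the "* 256 - 100" part of the fitted expression.
        q = np.round(probs * out_scale + out_zero)
        return np.clip(q, -128, 127).astype(np.int8)

    print(softmax_int8_scaled([101, 49, 6, -34, -75, -79, -38, 120, -55, 115]))
    # Tracks the expected array closely but is not bit-exact; several entries
    # differ by a few counts, consistent with the quality of the fit above.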