I am trying to define a custom 8 bit floating point format as follows:
Is it possible to define this as a numpy datatype? If not, what is the easiest way to convert a numpy array of dtype float16 to such a format (for storage) and convert it back (for calculations in float16), maybe using the bit operations of numpy?
Why:
I am trying to optimize a neural network on custom hardware (FPGA). For this, I am playing around with various float representations. I have already built a forward pass framework for my neural network with numpy, so something like the above would help me check the reduction in accuracy caused by storing the values in my custom datatype.
I'm by no means an expert in numpy, but I like to think about FP representation problems. The size of your array is not huge, so any reasonably efficient method should be fine. As far as I can tell, numpy has no built-in 8 bit FP dtype, I guess because the precision isn't so good.
To convert a one-dimensional float16 array to an array of bytes, each containing a single 8 bit FP value, all you need is:
import numpy as np
float16 = np.array([6.3, 2.557, 2.557, 6.3], dtype=np.float16) # Here's some data in an array
float8s = float16.tobytes()[1::2] # keep the high-order byte of each 2-byte pair
print(float8s)
>>> b'FAAF'
This just takes the high-order byte of each 16 bit float by lopping off the low-order byte, giving a 1 bit sign, 5 bit exponent and 2 bit significand. The high-order byte is always the second byte of each pair on a little-endian machine. I've tried it on a 2D array and it works the same, although tobytes() flattens the array, so you have to keep track of the shape yourself. Note that this truncates rather than rounds. Rounding would be a whole other can of worms.
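Since the question asks about numpy's bit operations, here is a minimal sketch of the same truncation done with an integer view and a shift instead of byte slicing. The function name float16_to_float8 is my own invention for illustration, not a numpy API. Because it shifts integer values rather than slicing raw bytes, it should not depend on the machine's byte order, and it keeps the array's shape.

```python
import numpy as np

def float16_to_float8(a):
    """Truncate float16 values to 1 sign + 5 exponent + 2 significand bits.

    Hypothetical helper: returns a uint8 array with the same shape as `a`.
    """
    a = np.ascontiguousarray(a, dtype=np.float16)
    bits = a.view(np.uint16)             # reinterpret the same bytes as integers
    return (bits >> 8).astype(np.uint8)  # keep only the high-order 8 bits

data = np.array([[6.3, 2.557], [2.557, 6.3]], dtype=np.float16)
packed = float16_to_float8(data)
print(packed.shape, packed.dtype)
print(packed.tobytes())
```

The result is a plain uint8 array, so it can be stored or sliced like any other numpy array.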
Getting back to 16 bits is just a matter of inserting zeros. I found this method by experiment and there is undoubtedly a better way: it reads the byte array as 8 bit integers, copies them into a new array of big-endian 16 bit integers, and then reinterprets the bytes of that array as float16. The big-endian step matters because it puts each 8 bit value in the second byte of its pair, which is the high-order byte when the pair is read back as a float16 on a little-endian machine.
float16 = np.frombuffer(np.array(np.frombuffer(float8s, dtype='u1'), dtype='>u2').tobytes(), dtype='f2')
print(float16)
>>> array([6. , 2.5, 2.5, 6. ], dtype=float16)
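The reverse direction can also be sketched with a shift and a view, which avoids the big-endian trick. Again, float8_to_float16 is a name I made up for illustration. It assumes the input holds one 8 bit value per element, laid out as sign, 5 bit exponent, 2 bit significand.

```python
import numpy as np

def float8_to_float16(packed):
    """Expand 8 bit values back to float16 by zero-filling the low byte.

    Hypothetical helper: `packed` is a uint8 array, one value per element.
    """
    packed = np.asarray(packed, dtype=np.uint8)
    bits = packed.astype(np.uint16) << 8  # move the 8 bits into the high byte
    return bits.view(np.float16)          # reinterpret the integers as floats

packed = np.frombuffer(b'FAAF', dtype=np.uint8).reshape(2, 2)
restored = float8_to_float16(packed)
print(restored)
```

The output keeps the shape of the input and is an ordinary float16 array, so it can go straight back into the float16 calculations.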
You can definitely see the loss of precision! I hope this helps. If this is sufficient, let me know. If not, I'd be up for looking deeper into it.