I am trying to use the shift() function on MNIST images.
Somehow, though, when I compare the original data with the shifted data, values that were exactly zero before shifting become very small nonzero values instead of staying zero. For example, a value that was zero before shifting comes out on the order of 1e-18 afterwards. Consequently, all of the other values are printed in scientific notation, such as 1.09000000e+02.
Here's the code I'm running.
import numpy as np
from scipy.ndimage import shift  # scipy.ndimage.interpolation is deprecated
from sklearn.datasets import fetch_openml

# as_frame=False returns NumPy arrays, so the integer indexing below works
mnist = fetch_openml('mnist_784', as_frame=False)
# cast to float64 so the pixel values are floats, as in the printed output
x, y = mnist['data'].astype(np.float64), mnist['target']
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

shuffle_index = np.random.permutation(60000)
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]

image = x_train[99]
reshaped = image.reshape(28, 28)
reshaped_2 = reshaped.reshape(784,)

print(reshaped[7:10, :])
print(shift(reshaped, [1, 0], cval=0)[8:11, :])
And here's the output:
[[ 0. 0. 0. 0. 0. 0. 0. 0. 32. 109. 109. 110. 109. 109.
109. 255. 253. 253. 253. 255. 211. 109. 47. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 32. 73. 73. 155. 217. 227. 252. 252. 253. 252. 252.
252. 253. 252. 252. 252. 253. 252. 252. 108. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 109. 252. 252. 252. 236. 226. 252. 231. 217. 215. 195.
71. 72. 71. 71. 154. 253. 252. 252. 108. 0. 0. 0. 0. 0.]]
[[-1.45736740e-17 2.08908499e-18 1.97425281e-17 1.32870826e-14
2.88143171e-14 2.90612090e-14 2.63726515e-14 2.89883698e-14
3.20000000e+01 1.09000000e+02 1.09000000e+02 1.10000000e+02
1.09000000e+02 1.09000000e+02 1.09000000e+02 2.55000000e+02
2.53000000e+02 2.53000000e+02 2.53000000e+02 2.55000000e+02
2.11000000e+02 1.09000000e+02 4.70000000e+01 8.06113136e-16
-1.58946559e-16 -9.39990682e-17 2.66688532e-17 -5.77791548e-17]
[-5.61019971e-16 2.32169340e-15 7.43877530e-15 3.20000000e+01
7.30000000e+01 7.30000000e+01 1.55000000e+02 2.17000000e+02
2.27000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 1.08000000e+02 3.29017268e-16
-6.57046610e-16 -1.22504799e-16 2.64344390e-17 -1.25480283e-16]
[-2.16877621e-15 7.92064171e-15 2.39544414e-14 1.09000000e+02
2.52000000e+02 2.52000000e+02 2.52000000e+02 2.36000000e+02
2.26000000e+02 2.52000000e+02 2.31000000e+02 2.17000000e+02
2.15000000e+02 1.95000000e+02 7.10000000e+01 7.20000000e+01
7.10000000e+01 7.10000000e+01 1.54000000e+02 2.53000000e+02
2.52000000e+02 2.52000000e+02 1.08000000e+02 3.04124747e-15
3.67217141e-17 -2.67076835e-16 -1.16801314e-16 -1.39584861e-16]]
What is causing this behavior? Is it a peculiarity of the MNIST dataset? Is it an error in my code?
The answer to "Is it possible to use vector methods to shift images stored in a numpy ndarray for data augmentation?" addresses how to do the shifting operation more efficiently, but it doesn't answer my other questions.
According to the shift documentation (emphasis mine):
The array is shifted using spline interpolation of the requested order
with
order : int, optional
The order of the spline interpolation, default is 3. The order has to be in the range 0-5.
I will not pretend to know exactly how this interpolation takes place, but it certainly seems to affect the shifted values. I figured that setting order=0 would disable the interpolation, and indeed it does. With the following changes in your code:
np.random.seed(42) # for reproducibility
# rest of your code as-is
print(reshaped[7:10,:])
print(shift(reshaped, [1,0], cval=0, order=0)[8:11,:]) # order=0
The results are indeed the same (no interpolation takes place during shifting):
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 168. 253.
200. 8. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 16. 235. 253.
80. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 65. 254. 169.
23. 0. 0. 0. 10. 14. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 168. 253.
200. 8. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 16. 235. 253.
80. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 65. 254. 169.
23. 0. 0. 0. 10. 14. 0. 0. 0. 0. 0. 0. 0. 0.]]
with
np.all(reshaped[7:10,:] == shift(reshaped, [1,0], cval=0, order=0)[8:11,:])
# True
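If you want to check this without downloading MNIST, here is a minimal sketch on a synthetic 28x28 array. The array `img` and the names `cubic` and `nearest` are made up for illustration and are not part of your code. My understanding is that with the default order=3 the data passes through floating-point spline filtering before being resampled, which would explain round-off residuals where exact zeros used to be, but treat that as an assumption rather than a full account of the internals.

```python
import numpy as np
from scipy.ndimage import shift

# Synthetic stand-in for one MNIST digit: mostly zeros with a bright block.
img = np.zeros((28, 28), dtype=np.float64)
img[8:20, 10:18] = 253.0

cubic = shift(img, [1, 0], cval=0)             # default order=3 (spline interpolation)
nearest = shift(img, [1, 0], cval=0, order=0)  # order=0, no spline interpolation

# Largest deviation from an exact one-row shift. With order=3 this is
# typically tiny round-off rather than exactly zero.
print(np.abs(cubic[1:] - img[:-1]).max())

# With order=0 the pixel values are copied, so an exact comparison holds.
print(np.array_equal(nearest[1:], img[:-1]))

# If you do need interpolation, compare with a tolerance instead of ==.
print(np.allclose(cubic[1:], img[:-1]))
```

For whole-pixel shifts like [1, 0], order=0 seems the simplest choice. For fractional shifts interpolation is the point, so comparing with np.allclose (or rounding the result) is probably the more practical route.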