I don't understand how normalization in PyTorch works.

I want to set the mean to 0 and the standard deviation to 1 across all columns in a tensor x of shape (2, 2, 3).

A simple example:
>>> import torch
>>> import torchvision.transforms as transforms
>>> x = torch.tensor([[[ 1.,  2.,  3.],
...                    [ 4.,  5.,  6.]],
...                   [[ 7.,  8.,  9.],
...                    [10., 11., 12.]]])
>>> norm = transforms.Normalize((0, 0), (1, 1))
>>> norm(x)
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],
        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])
So nothing has changed when applying the normalization transform. Why is that?
To answer your question: as you've now realized, torchvision.transforms.Normalize doesn't work as you anticipated. That's because it is not meant to:
- normalize: rescale your data to the range [0, 1], nor
- standardize: make your data have mean=0 and std=1 (which is what you're looking for).
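To make the distinction concrete, here is a sketch of both operations in plain torch. `min_max_normalize` and `standardize` are hypothetical helper names (neither is a torch or torchvision API), and both use global, whole-tensor statistics purely for brevity.

```python
import torch

def min_max_normalize(t: torch.Tensor) -> torch.Tensor:
    # Normalization: map the smallest value to 0 and the largest to 1.
    lo, hi = t.min(), t.max()
    return (t - lo) / (hi - lo)

def standardize(t: torch.Tensor) -> torch.Tensor:
    # Standardization (z-score): shift to mean 0, scale to std 1.
    # torch.std uses the unbiased (n - 1) estimator by default.
    return (t - t.mean()) / t.std()

x = torch.arange(1.0, 13.0).reshape(2, 2, 3)  # same values as the question's x
a = min_max_normalize(x)
b = standardize(x)
print(a.min().item(), a.max().item())  # 0.0 1.0
print(b.mean().item(), b.std().item())  # roughly 0 and 1, up to float rounding
```

Neither of these is what T.Normalize computes on its own; it only applies whatever shift and scale you hand it.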
The operation performed by T.Normalize (with T being torchvision.transforms) is merely a per-channel shift-scale transform:

    output[channel] = (input[channel] - mean[channel]) / std[channel]
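A minimal sketch of that formula with plain torch, assuming a float tensor laid out as (C, H, W) or (N, C, H, W), the layouts torchvision documents for Normalize. `shift_scale` is a hypothetical helper, not part of torchvision: it reshapes mean and std to (C, 1, 1) so that broadcasting lines them up with dim -3, the channel axis.

```python
import torch

def shift_scale(t, mean, std):
    # Hypothetical helper mirroring output[c] = (input[c] - mean[c]) / std[c].
    # mean/std become shape (C, 1, 1), so they broadcast along dim -3,
    # the channel axis for both (C, H, W) and (N, C, H, W) inputs.
    mean = torch.as_tensor(mean, dtype=t.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=t.dtype).view(-1, 1, 1)
    return (t - mean) / std

x = torch.arange(1.0, 13.0).reshape(2, 2, 3)
print(torch.equal(shift_scale(x, (0, 0), (1, 1)), x))  # True: (x - 0) / 1 == x
y = shift_scale(x, (1, 7), (1, 1))  # subtracts 1 from channel 0 and 7 from channel 1
print(y)  # both channels become [[0., 1., 2.], [3., 4., 5.]]
```

Nothing in the formula looks at the data's actual statistics: the output has mean 0 and std 1 only if you pass in the data's own mean and std.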
The parameter names mean and std are rather misleading: they do not refer to the desired output statistics, they are simply the values used for the shift and the scale. That's right: if you pass mean=0 and std=1, you get output = (input - 0) / 1 = input. Hence the result you received, where norm had no effect on your tensor values when you were expecting a tensor with mean 0 and variance 1.
However, if you provide the right mean and std parameters, i.e. mean=mean(data) and std=std(data), you end up computing the z-score of your data channel by channel, which is what is usually called 'standardization'. So to actually get mean=0 and std=1, you first need to compute the mean and standard deviation of your data.
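Why this works: a shift moves the mean but not the standard deviation, and dividing by a positive scale divides both. Writing x_c for the values of channel c, with mu_c and sigma_c its own mean and standard deviation:

```latex
\begin{aligned}
z_c &= \frac{x_c - \mu_c}{\sigma_c},
  \qquad \mu_c = \operatorname{mean}(x_c),\ \sigma_c = \operatorname{std}(x_c) > 0 \\
\operatorname{mean}(z_c) &= \frac{\operatorname{mean}(x_c) - \mu_c}{\sigma_c} = 0 \\
\operatorname{std}(z_c) &= \frac{\operatorname{std}(x_c)}{\sigma_c} = 1
\end{aligned}
```

With arbitrary values instead (say mean=0 and std=1), the output statistics are just (mean(x_c) - mean) / std and std(x_c) / std, which is why those values leave x unchanged.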
If you do:
>>> mean, std = x.mean(), x.std()
>>> mean, std
(tensor(6.5000), tensor(3.6056))
you get the global mean and the global standard deviation, respectively.
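Using those two scalars directly gives a global standardization. A sketch with plain tensor arithmetic (no torchvision), reusing the question's x: every element is shifted and scaled by the same two numbers, so the whole tensor, not each channel, ends up with mean 0 and std 1.

```python
import torch

x = torch.tensor([[[1., 2., 3.],
                   [4., 5., 6.]],
                  [[7., 8., 9.],
                   [10., 11., 12.]]])

# One scalar mean and one scalar std for the whole tensor
# (torch.std defaults to the unbiased, n - 1 estimator).
mean, std = x.mean(), x.std()
z = (x - mean) / std
print(mean, std)  # tensor(6.5000) tensor(3.6056)
print(z[0, 0])    # tensor([-1.5254, -1.2481, -0.9707])
```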
Instead, what you want is to measure the first- and second-order statistics per channel. T.Normalize expects a tensor of shape (C, H, W) or (N, C, H, W) and treats dim -3 as the channel axis: for a 3D tensor such as x that is dim=0 (for a batched 4D tensor it would be dim=1). Therefore, we need to apply torch.mean and torch.std over all dimensions except the channel one. Both functions accept a tuple of dimensions:
>>> mean, std = x.mean((1, 2)), x.std((1, 2))
>>> mean, std
(tensor([3.5000, 9.5000]), tensor([1.8708, 1.8708]))
These are the mean and standard deviation of x measured along each channel. From there you can use T.Normalize(mean, std) to transform x with the correct shift-scale parameters:
>>> norm = T.Normalize(mean, std)
>>> norm(x)
tensor([[[-1.3363, -0.8018, -0.2673],
         [ 0.2673,  0.8018,  1.3363]],
        [[-1.3363, -0.8018, -0.2673],
         [ 0.2673,  0.8018,  1.3363]]])
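The same per-channel standardization can be written with plain tensor ops, which makes the channel axis an explicit argument. A minimal sketch: `standardize_per_channel` is a hypothetical helper (not a torch or torchvision API), and the channel axis is an assumption you pass in, 0 for a single (C, H, W) tensor like x and 1 for a batched (N, C, H, W) tensor.

```python
import torch

def standardize_per_channel(t: torch.Tensor, channel_dim: int) -> torch.Tensor:
    # Hypothetical helper: reduce over every dimension except the channel one.
    channel_dim = channel_dim % t.ndim
    dims = tuple(d for d in range(t.ndim) if d != channel_dim)
    # keepdim=True leaves size-1 dims in place, so the statistics broadcast back.
    mean = t.mean(dim=dims, keepdim=True)
    std = t.std(dim=dims, keepdim=True)  # unbiased (n - 1), like x.std()
    return (t - mean) / std

# Single (C, H, W) tensor: channels on dim 0.
x = torch.arange(1.0, 13.0).reshape(2, 2, 3)
z = standardize_per_channel(x, channel_dim=0)
print(z[0])  # approximately [[-1.3363, -0.8018, -0.2673], [0.2673, 0.8018, 1.3363]]

# Batched (N, C, H, W) tensor: channels on dim 1.
torch.manual_seed(0)
xb = torch.randn(4, 3, 5, 5) * 10 + 3
zb = standardize_per_channel(xb, channel_dim=1)
print(zb.mean(dim=(0, 2, 3)))  # close to 0 for each of the 3 channels
print(zb.std(dim=(0, 2, 3)))   # close to 1 for each of the 3 channels
```

Using keepdim=True avoids the manual view(-1, 1, 1) reshape, and the same function covers both layouts because only the channel index changes.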