None of the similar questions worked, so please do not flag this as a duplicate.
PyTorch's BatchNorm2d expects an input in the format N C H W, where
N = Batch size
C = Channels
H = Height
W = Width
as indicated in the docs: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
If we test this with a random tensor, we get an error:
import torch
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
torch.nn.BatchNorm2d(h)(torch.rand(n,c,h,w)) # num_features=64, but dim 1 (channels) has size 1 -> RuntimeError
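Running this raises a RuntimeError because BatchNorm2d keeps one running statistic per channel, and here num_features=64 while the channel dimension (dim 1) of the input has size 1. A minimal sketch of where the mismatch lives (the exact error wording varies across PyTorch versions):
import torch

bn = torch.nn.BatchNorm2d(64)    # num_features=64: one statistic per channel
print(bn.running_mean.shape)     # torch.Size([64])

x = torch.rand(32, 1, 64, 512)   # dim 1 (the channel dimension) has size 1
try:
    bn(x)
except RuntimeError as e:
    print(e)                     # complains that the channel counts do not match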
The following code "works", but the input is actually in the format "NHWC":
import torch
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
x = torch.rand(n,h,w,c)
x = torch.nn.BatchNorm2d(h)(x) # no error: dim 1 has size h=64, which matches num_features
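This only "works" because dim 1 happens to have size 64 (the height), so the layer silently normalizes over heights instead of channels. If your data really is stored as NHWC, a common fix is to permute it to NCHW before the layer; a small sketch (permute is standard PyTorch, the variable names are mine):
import torch
n, c, h, w = 32, 1, 64, 512
x_nhwc = torch.rand(n, h, w, c)           # data laid out as NHWC
x_nchw = x_nhwc.permute(0, 3, 1, 2)       # reorder dimensions to NCHW
x_out = torch.nn.BatchNorm2d(c)(x_nchw)   # num_features=1 now matches dim 1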
The thing here is that changing the values of the N, C, H, or W variables does not change the dimension order PyTorch expects; those are just variable names. If you provide input shaped (n, h, w, c) as above, the dimensions are interpreted positionally: n -> N, h -> C (your height becomes the number of channels, not the height you are thinking of), w -> H, and c -> W.
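In other words, BatchNorm2d always treats dimension 1 as the channel dimension, no matter what you call your variables. A quick sketch to see this, assuming a freshly created module (which is in training mode, so batch statistics are used):
import torch
x = torch.rand(32, 64, 512, 1)   # we may *think* of this as (N, H, W, C) ...
y = torch.nn.BatchNorm2d(64)(x)  # ... but dim 1 (size 64) is what gets normalized
print(y[:, 0].mean().item(), y[:, 0].std().item())  # roughly 0 and 1 per dim-1 slice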
Returning to the question: the number of channels in your input data should match the number of channels in nn.BatchNorm2d.
In your case, the number of channels you set is one, but BatchNorm2d is expecting 64. To fix this, you can follow one of these examples:
Example:
import torch
n, c, h, w = 32, 64, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(h)(x) # works because c == h == 64: dim 1 now has 64 channels
and
import torch
n, c, h, w = 32, 1, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(c)(x) # num_features taken from the channel count
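If you do not want to hard-code the channel count, one option (my suggestion, not part of the original examples) is to read num_features off dim 1 of the tensor itself:
import torch
x = torch.rand(32, 1, 64, 512)
x = torch.nn.BatchNorm2d(x.shape[1])(x)  # num_features taken from the channel dimension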
I hope this helps you. Thanks!