My model training involves encoding multiple variants of the same image and then summing the produced representations over all variants of that image.
The data loader produces tensor batches of shape [batch_size, num_variants, 1, height, width], where the 1 corresponds to the image color channels.
How can I train my model with minibatches in PyTorch? I am looking for a proper way to forward all batch_size × num_variants images through the network and sum the results over each group of variants.
My current solution involves flattening the first two dimensions and using a for loop to sum the representations, but I feel like there should be a better way, and I am not sure gradients will propagate correctly through it.
Not sure I understood you correctly, but I guess this is what you want (say the batched image tensor is called image):
Nb, Nv, inC, inH, inW = image.shape
# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv, inC, inH, inW)
output = model(image)
_, outC, outH, outW = output.shape
# reshapes the output such that dim==1 indicates variants
output = output.reshape(Nb, Nv, outC, outH, outW)
# sum over the variants and drop the summed dimension -> [Nb, outC, outH, outW]
output = output.sum(dim=1, keepdim=False)
I used inC, outC, inH, etc. in case the input and output channels/sizes are different.
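To make the pattern concrete, here is a minimal self-contained sketch using a dummy `nn.Conv2d` as the model and small illustrative sizes (the model and dimensions are placeholders, not from the original post). It also checks that the reshape-and-sum result matches a per-variant loop, so gradients flow identically either way:

```python
import torch
import torch.nn as nn

# Dummy model and illustrative sizes (placeholders for your own model/batch)
model = nn.Conv2d(1, 4, kernel_size=3, padding=1)
Nb, Nv, inC, inH, inW = 2, 3, 1, 8, 8
image = torch.randn(Nb, Nv, inC, inH, inW)

# Fold variants into the batch dimension, forward once, unfold, then sum
out = model(image.reshape(Nb * Nv, inC, inH, inW))
_, outC, outH, outW = out.shape
out = out.reshape(Nb, Nv, outC, outH, outW).sum(dim=1)

# Reference: loop over variants and sum (same result, but slower)
ref = sum(model(image[:, v]) for v in range(Nv))
assert torch.allclose(out, ref, atol=1e-5)
print(out.shape)  # torch.Size([2, 4, 8, 8])
```

Since `reshape` and `sum` are both differentiable, autograd tracks the whole pipeline; no manual gradient bookkeeping is needed.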