pytorch tensor mini-batch

How to handle samples with multiple images in a PyTorch image processing model?


My model training involves encoding multiple variants of the same image, then summing the resulting representations over all variants of that image.

The data loader produces tensor batches of shape [batch_size, num_variants, 1, height, width]. The 1 corresponds to the image color channel.

How can I train my model with minibatches in PyTorch? I am looking for a proper way to forward all batch_size×num_variants images through the network and sum the results within each group of variants.

My current solution involves flattening the first two dimensions and using a for loop to sum the representations, but I feel like there should be a better way, and I am not sure the gradients will be tracked correctly through all of this.


Solution

  • Not sure I understood you correctly, but I guess this is what you want (say the batched image tensor is called image):

    Nb, Nv, inC, inH, inW = image.shape
    
    # treat each variant as if it's an ordinary image in the batch
    image = image.reshape(Nb*Nv, inC, inH, inW)
    
    output = model(image)
    _, outC, outH, outW = output.shape
    
    # reshape the output so that dim 1 indexes the variants again
    output = output.reshape(Nb, Nv, outC, outH, outW)
    
    # sum over the variants and drop the summed dimension, giving [Nb, outC, outH, outW]
    output = output.sum(dim=1, keepdim=False)
    

    I used inC, outC, inH, etc. in case the input and output channels/sizes are different.
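
On the gradient concern: reshape and sum are ordinary differentiable operations, so autograd tracks them like any other op. Below is a minimal sketch to check this, assuming a small made-up convolutional encoder (the `model`, the tensor sizes, and the random input are all placeholders, not the asker's actual setup). It compares the reshape approach against a plain loop over variants and checks that every variant receives a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in encoder; any module mapping [N, 1, H, W] to [N, C, H', W'] would do.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1),
)

# Made-up sizes for illustration only.
Nb, Nv, inC, inH, inW = 2, 3, 1, 8, 8
image = torch.randn(Nb, Nv, inC, inH, inW, requires_grad=True)

# Vectorized: fold the variants into the batch dimension, run the model once,
# unfold again, then sum over the variant dimension.
flat = image.reshape(Nb * Nv, inC, inH, inW)
out = model(flat)
summed = out.reshape(Nb, Nv, *out.shape[1:]).sum(dim=1)

# Reference: one forward pass per variant, accumulated in a Python loop.
looped = sum(model(image[:, v]) for v in range(Nv))

same_values = torch.allclose(summed, looped, atol=1e-5)

# Backpropagate through the vectorized path and check that each
# (sample, variant) slice of the input received a nonzero gradient.
summed.sum().backward()
grad_reaches_all_variants = bool(
    (image.grad.abs().sum(dim=(2, 3, 4)) > 0).all()
)

print(summed.shape, same_values, grad_reaches_all_variants)
```

One caveat: if the model contains layers that compute statistics across the batch in training mode, such as BatchNorm, the single flattened pass and the per-variant loop will generally not give identical outputs, because the statistics are then computed over different sets of images.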