Minimum Batch Size vs Batch Normalisation


I have been using VGG16 with transfer learning to train a CNN in PyTorch. The maximum batch size I can use is 16, as I am constrained by VRAM (GTX 1070).

I have noticed that when training with batch normalisation, the performance of my model on the test set is very poor, almost random classification, whereas without batch normalisation, my model performs reasonably well.

Ideally, I would like to increase my batch size to at least 32 and take advantage of batch normalisation. Part of my question is whether anyone knows how to manage the memory constraints of my GPU to achieve this goal.

The other half of my question relates to batch normalisation. I originally suspected an error in my pre-processing pipeline, but after a lot of searching I could not find one. I have read in a couple of articles that batch normalisation can worsen performance with small batch sizes. Could anyone with more experience tell me whether what I am seeing appears logical?


Solution

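  • Batch normalization works best with reasonably large batch sizes. During training, each batch norm layer normalizes its inputs with the mean and variance of the current mini-batch; the smaller the batch, the noisier these estimates are, which can destabilize training and degrade the model's performance. At test time, in evaluation mode (model.eval()), the layer uses running averages of those statistics collected during training instead, so a network trained on noisy per-batch statistics may behave quite differently at evaluation.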

    A few strategies you could try:

    1. Reduce the size of the model, or modify the architecture to use fewer layers or smaller feature maps.
    2. Use mixed precision training, which reduces memory usage by running certain operations in a lower-precision data type (16-bit floating point).
    3. Increase the available memory, for example by upgrading your GPU or using a cloud computing service, so that you can train with a larger batch size.

    In summary, it is plausible that batch normalization is hurting your model's performance, given the small batch size you are using.
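To illustrate why small batches make batch norm statistics noisy, the stdlib-only sketch below draws synthetic "activations" from a standard normal distribution (an assumption for illustration, so the true mean is 0 and the true variance is 1), groups them into batches, and measures how much the per-batch mean and variance fluctuate. For independent samples, the spread of the batch mean shrinks roughly as 1/sqrt(batch size), so going from 16 to 32 cuts it by about 30%. The function `batch_stat_spread` is hypothetical. This toy models a single feature, as in a fully connected layer; the `BatchNorm2d` layers in a convolutional network also pool statistics over spatial positions, so their estimates are less noisy than shown here, although neighbouring positions are correlated.

```python
import random
import statistics


def batch_stat_spread(batch_size, num_batches=2000, seed=0):
    """Return the std. dev. of per-batch mean and variance estimates.

    Activations are drawn from a standard normal distribution (an assumption
    for illustration), so the true mean is 0 and the true variance is 1.
    """
    rng = random.Random(seed)
    means, variances = [], []
    for _ in range(num_batches):
        batch = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        means.append(statistics.fmean(batch))
        # Batch norm normalizes with the biased (population) variance of the batch.
        variances.append(statistics.pvariance(batch))
    return statistics.stdev(means), statistics.stdev(variances)


for b in (4, 16, 32, 64):
    mean_spread, var_spread = batch_stat_spread(b)
    print(f"batch size {b:>2}: std of batch mean = {mean_spread:.3f}, "
          f"std of batch variance = {var_spread:.3f}")
```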