I'm implementing a semantic segmentation model on images. As a good practice, I tested my training pipeline with just one image and tried to over-fit that image. To my surprise, when training on exactly the same image, the loss goes to near 0 as expected, but when evaluating THE SAME image, the loss is much, much higher, and it keeps going up as training continues. So the segmentation output is garbage when training=False, but when run with training=True it works perfectly.
So that anyone can reproduce this, I took the official segmentation tutorial and modified it a little to train a convnet from scratch on just 1 image. The model is very simple: just a sequence of Conv2D with batch normalization and ReLU. The results are the following.
As you can see, the loss and eval_loss are really different, and running inference on the image gives a perfect result in training mode, while in eval mode it is garbage.
I know BatchNormalization behaves differently at inference time, since it uses the averaged statistics computed during training. Nonetheless, since we are training on just 1 image and evaluating on that same image, this shouldn't happen, right? Moreover, I implemented the same architecture with the same optimizer in PyTorch, and this does not happen there: with PyTorch it trains and the eval_loss converges to the train loss.
Here you can find the above-mentioned notebook: https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu, and at the end also the PyTorch implementation.
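Here is also a minimal sketch of the kind of setup I mean (not the notebook code itself; random data and a smaller net, so the exact numbers will differ, but it shows the training=True / training=False gap):

```
import numpy as np
import tensorflow as tf

# One random "image" and integer mask, standing in for the single tutorial image.
x = np.random.rand(1, 128, 128, 3).astype("float32")
y = np.random.randint(0, 3, size=(1, 128, 128)).astype("int32")

def make_model(num_classes=3):
    inputs = tf.keras.Input(shape=(128, 128, 3))
    h = inputs
    for filters in (32, 64, 64):
        h = tf.keras.layers.Conv2D(filters, 3, padding="same")(h)
        h = tf.keras.layers.BatchNormalization()(h)
        h = tf.keras.layers.ReLU()(h)
    outputs = tf.keras.layers.Conv2D(num_classes, 1)(h)  # per-pixel logits
    return tf.keras.Model(inputs, outputs)

model = make_model()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss_fn)

# Over-fit the single image: the reported training loss goes towards 0.
model.fit(x, y, epochs=200, verbose=0)

# Same image, two modes: BatchNorm uses batch statistics vs. moving averages.
print("training=True  loss:", float(loss_fn(y, model(x, training=True))))
print("training=False loss:", float(loss_fn(y, model(x, training=False))))
```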
It had more to do with the default values that TensorFlow uses. BatchNormalization has a momentum parameter, which controls the averaging of batch statistics. The formula is: moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
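The Keras default is momentum=0.99, so the moving statistics move only very slowly towards the batch statistics. A toy calculation (plain Python, not the layer itself) shows how far they can still lag behind even after many steps on the same single image:

```
# Toy illustration of the formula above, not TensorFlow code.
# Suppose the (single) batch always has mean 5.0 and the moving mean starts at 0.
batch_mean = 5.0

for momentum in (0.99, 0.9, 0.0):
    moving_mean = 0.0
    for _ in range(100):  # 100 training steps on the same image
        moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
    print(f"momentum={momentum}: moving_mean={moving_mean:.3f}")

# momentum=0.99 -> ~3.17  (still far from 5.0, so eval-time statistics are wrong)
# momentum=0.9  -> ~5.00
# momentum=0.0  -> 5.00   (matches the batch exactly)
```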
If you set momentum=0.0 in the BatchNorm layer, the averaged statistics match perfectly with the statistics of the current batch (which is just 1 image). If you do so, you will see that the validation loss almost immediately matches the training loss. It also works if you try momentum=0.9 (the equivalent of PyTorch's default), and it converges faster (as in PyTorch).
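For reference, a sketch of how that change would look in one of the Conv2D + BatchNorm + ReLU blocks (the helper name and signature are mine, not the tutorial's):

```
import tensorflow as tf

def conv_bn_relu(x, filters, momentum=0.9):
    # momentum=0.9 mirrors PyTorch's default behaviour (PyTorch's momentum=0.1
    # weights the *new* batch statistics, so it corresponds to 0.9 here);
    # momentum=0.0 makes the moving statistics equal to the last batch's.
    x = tf.keras.layers.Conv2D(filters, 3, padding="same")(x)
    x = tf.keras.layers.BatchNormalization(momentum=momentum)(x)
    return tf.keras.layers.ReLU()(x)
```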