I'm trying to train a simple CNN with Flux and am running into a weird issue. During training the loss appears to go down (indicating that it's working), but despite what the loss curve suggested, the "trained" model's output was very bad. When I calculated the loss by hand, I noticed that it differed from what the training indicated it should be: the model was acting like it hadn't been trained at all.
I then started comparing the loss returned inside the gradient call with the loss computed outside it, and after a lot of digging I think the problem is related to the BatchNorm layer. Consider the following minimal example:
using Flux
x = rand(100,100,1,1) # say a 100x100 greyscale image with 1 channel and batch size 1
y = @. 5*x + 3 #output image, some relationship to the input values (doesn't matter for this)
m = Chain(BatchNorm(1),Conv((1,1),1=>1)) #very simple model (doesn't really do anything but illustrates the problem)
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m) #loss calculated by gradient
l_final = Flux.mse(m(x),y) #loss calculated again using the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
All three losses above are different, sometimes drastically so (when I ran it just now I got 22.6, 30.7, and 23.0), even though I'd expect them all to be the same.
Interestingly, if I remove the BatchNorm layer, the losses all agree. Running this:
using Flux
x = rand(100,100,1,1) # say a 100x100 greyscale image with 1 channel and batch size 1
y = @. 5*x + 3 #output image
m = Chain(Conv((1,1),1=>1))
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
l_final = Flux.mse(m(x),y)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
This produces the same number for each of the three loss calculations.
Why does including the BatchNorm layer change the value of the loss like this?
My (limited) understanding was that BatchNorm just normalizes the input values. I understand that this could make the loss differ between the unnormalized and normalized cases, but I don't understand why it would produce different losses for the same inputs on the same model, without any of the model's parameters being updated.
Look at the documentation of BatchNorm:
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true, active=nothing,
eps=1f-5, momentum= 0.1f0)
Batch Normalization (https://arxiv.org/abs/1502.03167) layer. channels should
be the size of the channel dimension in your data (see below).
Given an array with N dimensions, call the N-1th the channel dimension. For a
batch of feature vectors this is just the data dimension, for WHCN images it's
the usual channel dimension.
BatchNorm computes the mean and variance for each D_1×...×D_{N-2}×1×D_N input
slice and normalises the input accordingly.
If affine=true, it also applies a shift and a rescale to the input through to
learnable per-channel bias β and scale γ parameters.
After normalisation, elementwise activation λ is applied.
If track_stats=true, accumulates mean and var statistics in training phase that
will be used to renormalize the input in test phase.
Use testmode! during inference.
Examples
≡≡≡≡≡≡≡≡≡≡
julia> using Statistics
julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels
julia> m = BatchNorm(3);
julia> Flux.trainmode!(m);
julia> isapprox(std(m(xs)), 1, atol=0.1) && std(xs) != std(m(xs))
true
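To make the docstring concrete, here is a minimal sketch (assuming a recent Flux release, Float32 input, and a freshly constructed layer) that evaluates the same BatchNorm(1) once in its default mode outside any gradient call and once after forcing training mode with Flux.trainmode!, the function used in the docstring example. Outside a gradient call the fresh layer normalises with its initial running statistics, so it barely changes the input; in training mode it standardises the input with the batch's own mean and variance.

```julia
using Flux, Statistics

x = rand(Float32, 100, 100, 1, 1)  # one 100×100 greyscale image (WHCN layout)

# Default mode outside a gradient call = test mode: a fresh layer normalises
# with its running stats μ = 0, σ² = 1, so the output is essentially x itself.
bn = BatchNorm(1)
y_test = bn(x)

# Forced training mode (what withgradient gives you automatically): the layer
# normalises with this batch's own mean and variance, giving mean ≈ 0, std ≈ 1.
Flux.trainmode!(bn)
y_train = bn(x)

println("test mode:  mean = ", mean(y_test), ", std = ", std(y_test))
println("train mode: mean = ", mean(y_train), ", std = ", std(y_train))
```

The two outputs feed very different values into the following Conv layer, which is why the loss computed inside withgradient does not match the one computed outside it.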
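The key bit here is that track_stats=true by default, combined with active=nothing, which lets Flux switch the layer's mode automatically. Inside withgradient the layer is in training mode: it normalises with the current batch's mean and variance and, because track_stats=true, also updates its running mean μ and variance σ². Outside a gradient call it is in test mode and normalises with those running statistics instead. So l_init uses the initial running statistics, l_grad uses the batch statistics, and l_final uses running statistics that withgradient has just updated, which is why all three differ even though no trainable parameter has changed. If you don't want this behaviour, initialise your model with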
m = Chain(BatchNorm(1, track_stats=false),Conv((1,1),1=>1)) # track_stats=false: always normalise with the batch's own statistics
and all three losses will be identical, just as in your second example.
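A quick way to check the track_stats=false fix is to rerun the question's toy model with it. This is a minimal sketch that assumes Float32 input (Flux's default element type) and otherwise mirrors the question's code; since the layer now normalises with the current batch's statistics in every mode, the three losses should agree up to floating-point rounding.

```julia
using Flux

x = rand(Float32, 100, 100, 1, 1)  # one 100×100 greyscale image, batch size 1
y = @. 5 * x + 3

# track_stats=false: BatchNorm uses the batch's own mean/variance in every mode
m = Chain(BatchNorm(1; track_stats=false), Conv((1, 1), 1 => 1))

l_init = Flux.mse(m(x), y)                                   # outside a gradient call
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x), y), m)  # inside a gradient call
l_final = Flux.mse(m(x), y)                                  # outside again

println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
```

One design consequence worth keeping in mind: with track_stats=false the layer also uses batch statistics at inference time, so a prediction for a single image depends on that image's own mean and variance.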
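The BatchNorm layer's running statistics are initialised to zero mean and unit variance, and your input data doesn't match them (rand draws from [0, 1), so its mean is about 0.5 and its variance about 1/12). Each call in training mode nudges the running statistics toward the batch statistics by the momentum factor (0.1 by default), which is why you get a changing output even with repeated identical input when track_stats=true, as far as I can see (from a quick look).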
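To see that drift numerically, the sketch below (assuming Float32 input and the default momentum = 0.1f0 from the docstring) calls a fresh layer repeatedly in training mode on the same image and compares its running mean and variance with a hand-computed exponential moving average. The closed form used here, μ_k = (1 - 0.9^k)·mean(x) and σ²_k = 0.9^k + (1 - 0.9^k)·var(x), is my own derivation from the momentum update starting at μ = 0, σ² = 1; whether Flux stores the corrected or uncorrected batch variance is an assumption, but for 10,000 pixels the difference is far below the tolerance used.

```julia
using Flux, Statistics

x = rand(Float32, 100, 100, 1, 1)  # the same image is fed in every time
bn = BatchNorm(1)                   # running stats start at μ = 0, σ² = 1
momentum = 0.1f0                    # default momentum from the docstring

Flux.trainmode!(bn)                 # behave as inside withgradient
k = 5
for _ in 1:k
    bn(x)                           # each training-mode call updates bn.μ and bn.σ²
end

# Exponential moving average started from μ₀ = 0, σ²₀ = 1 and applied k times
# with identical batch statistics (var here is Statistics' corrected estimator).
decay = (1 - momentum)^k
μ_expected = (1 - decay) * mean(x)
σ²_expected = decay * 1 + (1 - decay) * var(x)

println("running mean: ", bn.μ[1], " (expected ≈ ", μ_expected, ")")
println("running var:  ", bn.σ²[1], " (expected ≈ ", σ²_expected, ")")

# Back in test mode the layer normalises with these drifted statistics,
# so the same x now comes out noticeably different from a fresh layer's output.
Flux.testmode!(bn)
println("max |bn(x) - x| in test mode: ", maximum(abs.(bn(x) .- x)))
```

Every withgradient call on the original model performs one such update, so l_final keeps moving even when x and the trainable parameters stay the same.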