I am trying to build a variational autoencoder (VAE) for CIFAR-10 images with Keras. It works perfectly on MNIST data, but with CIFAR-10 both of my losses (reconstruction loss and KL loss) are NaN when I call `fit`, as you can see here: [screenshot: NaN loss]
Here is my VAE with a custom training step:
Note: CIFAR-10 image shape = (32, 32, 3); latent dimension = 2.
class VAE(Model):
    """Variational autoencoder trained with a custom ``train_step``.

    The training objective is the sum of
      * a reconstruction term: Keras binary cross-entropy between the input
        batch and the decoder output, summed over axes 1 and 2 and averaged
        over the batch, and
      * a KL term: the closed-form KL divergence between the diagonal
        Gaussian given by the encoder's ``(z_mu, z_log_var)`` outputs and a
        standard normal, summed over axis 1 and averaged over the batch.

    Binary cross-entropy treats every input value as a target probability,
    so the images given to ``fit`` must be scaled into ``[0, 1]`` (for raw
    8-bit pixels, divide by 255). With unscaled targets the reconstruction
    loss is effectively unbounded below and training can diverge to NaN.

    Args:
        encoder: model returning ``(z_mu, z_log_var, z)`` for an input
            batch. The KL term treats the second output as a log-variance.
        decoder: model mapping latent samples ``z`` back to reconstructions.
            Presumably sigmoid-activated so outputs lie in ``[0, 1]`` --
            TODO confirm against the decoder definition.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of each loss component, reported by ``fit``.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the running loss means."""
        # ``fit`` hands over ``(x, y[, sample_weight])`` tuples when the dataset
        # yields labels; an autoencoder reconstructs its input, so keep only ``x``.
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            z_mu, z_log_var, z = self.encoder(data)
            z_decoded = self.decoder(z)

            # Keras' binary_crossentropy averages over the last axis; summing
            # over axes 1 and 2 then gives one reconstruction error per sample
            # (assumes (batch, height, width, channels) images -- TODO confirm).
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, z_decoded), axis=(1, 2)
                )
            )

            # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)) per latent dimension.
            kl_loss = -(1 + z_log_var - z_mu**2 - tf.exp(z_log_var)) / 2
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
My encoder: [image: encoder graph]
My decoder: [image: decoder graph]
Does anyone have an idea what is causing the NaN losses?
In case this helps someone: I faced exactly the same problem, and what fixed it for me was keeping binary_crossentropy but making sure the data was normalized, i.e., that all image pixel values were between 0 and 1. Something like this might help:
# Rescale raw 0-255 pixel values into [0, 1] so they are valid targets for binary_crossentropy.
# `<anything else you want>` is a placeholder for any other ImageDataGenerator arguments.
datagen = ImageDataGenerator(rescale=1./255, <anything else you want>)
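Keeping the values between 0 and 1 is important because binary cross-entropy treats each pixel value as a target probability. With raw pixel values up to 255, the loss is effectively unbounded below, so minimizing it sets up a runaway (positive-feedback) loop. The decoder's outputs are pushed to ever more extreme values until the computation overflows and produces NaN. If you load CIFAR-10 as NumPy arrays (e.g. with `keras.datasets.cifar10.load_data()`) instead of using a generator, divide by 255 directly: `x_train = x_train.astype("float32") / 255.0`.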