
Difference between TensorFlow model fit and train_on_batch


I am building a vanilla DQN model to play the OpenAI gym Cartpole game.

However, the training step behaves differently depending on which call I use. In that step I feed in the states as input and the target Q values as the labels. If I use model.fit(x=states, y=target_q), it works fine and the agent eventually learns to play the game well. If I use model.train_on_batch(x=states, y=target_q), the loss doesn't decrease and the model plays no better than a random policy.
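For context, the target_q labels in a vanilla DQN step are usually built as in the minimal NumPy sketch below. It assumes the common recipe and is not the code in the linked repository; `build_targets` and its argument names are hypothetical. Only the Q value of the action actually taken is replaced by the Bellman target r + gamma * max Q(s', a'), so the other outputs keep the model's own predictions and contribute zero error.

```python
import numpy as np

def build_targets(q_current, q_next, actions, rewards, dones, gamma=0.99):
    """Hypothetical helper: target Q values for a batch of transitions.

    q_current: (batch, n_actions) predictions for the sampled states
    q_next:    (batch, n_actions) predictions for the next states
    """
    # Start from the current predictions so untouched actions have zero error
    targets = np.array(q_current, dtype=np.float64, copy=True)
    # No bootstrapping from the next state when the episode ended (done == 1)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    bootstrap = np.max(np.asarray(q_next, dtype=np.float64), axis=1) * not_done
    rows = np.arange(len(actions))
    # Overwrite only the entry of the action that was actually taken
    targets[rows, np.asarray(actions)] = np.asarray(rewards, dtype=np.float64) + gamma * bootstrap
    return targets
```

With a Keras model this would typically be used as `target_q = build_targets(model.predict(states), model.predict(next_states), actions, rewards, dones)`, followed by the `fit` or `train_on_batch` call on `(states, target_q)`.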

What is the difference between fit and train_on_batch? My understanding is that fit calls train_on_batch under the hood with a batch size of 32, so if I set the batch size equal to the size of the data I feed in, the two should behave the same.
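The batch-size part of this reasoning can be checked with simple arithmetic. In Keras, `fit` splits array inputs into batches of `batch_size` samples (default 32) and performs one update per batch, while `train_on_batch` performs exactly one update on whatever it receives. The sketch below assumes a replay batch of 64 transitions purely for illustration; `updates_per_fit_call` is a hypothetical helper and ignores options such as `steps_per_epoch`.

```python
import math

def updates_per_fit_call(n_samples, batch_size=32, epochs=1):
    """Hypothetical helper: gradient updates performed by one fit() call."""
    # Each epoch covers the data in ceil(n / batch_size) batches; the last may be smaller
    return epochs * math.ceil(n_samples / batch_size)

# Assumed replay batch of 64 transitions (illustrative only)
print(updates_per_fit_call(64))                 # default batch_size=32 -> 2 updates
print(updates_per_fit_call(64, batch_size=64))  # batch_size=len(states) -> 1 update
```

So `fit` only matches a single `train_on_batch` call in update count when `batch_size` is passed explicitly as the full size of the data; with the default, any batch larger than 32 samples is trained as several smaller updates.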

The full code is here if more contextual information is needed to answer this question: https://github.com/ultronify/cartpole-tf


Solution

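  • model.fit trains for one or more epochs, and each epoch is split into as many batches as needed to cover the data, so a single fit call can perform many weight updates. model.train_on_batch, as the name implies, performs exactly one update on the single batch you pass it.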

    To give a concrete example, imagine you are training a model on 10 images with a batch size of 2. model.fit will train on all 10 images, so it will update the weights 5 times per epoch. (You can specify multiple epochs, in which case it iterates over your dataset repeatedly.) model.train_on_batch performs a single weight update, because you only give the model one batch. With a batch size of 2, you would pass model.train_on_batch two images.

    And even if model.fit did call model.train_on_batch under the hood (I don't think it does), it would call it many times in a loop, once per batch. Here's pseudocode to explain:

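    def fit(model, x, y, batch_size=32, epochs=1):
        # Simplified: ignores shuffling, callbacks, and validation
        for _ in range(epochs):
            # Walk the data in consecutive slices of batch_size samples
            for start in range(0, len(x), batch_size):
                # One call = one gradient update on this slice
                model.train_on_batch(x[start:start + batch_size], y[start:start + batch_size])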
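To see the loop in action without TensorFlow, the toy sketch below swaps the real model for a hypothetical `CountingModel` stub that just records each `train_on_batch` call. It also adds per-epoch shuffling, because Keras `fit` shuffles the training data before each epoch by default (`shuffle=True`). This is a simplified model of the behaviour, not Keras's actual implementation; the `fit` function and `seed` parameter here are assumptions for illustration.

```python
import random

class CountingModel:
    """Hypothetical stand-in for a Keras model: records every train_on_batch call."""
    def __init__(self):
        self.calls = []

    def train_on_batch(self, x, y):
        self.calls.append((list(x), list(y)))

def fit(model, x, y, batch_size=32, epochs=1, shuffle=True, seed=0):
    rng = random.Random(seed)
    indices = list(range(len(x)))
    for _ in range(epochs):
        if shuffle:
            # Mimic Keras's default of shuffling the data before each epoch
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # One train_on_batch call per batch, i.e. one weight update each
            model.train_on_batch([x[i] for i in batch], [y[i] for i in batch])

model = CountingModel()
images = list(range(10))
labels = [i % 2 for i in images]
fit(model, images, labels, batch_size=2)
print(len(model.calls))  # 5 updates for 10 samples with batch size 2
```

In real Keras, the closest `fit` equivalent of one `train_on_batch(states, target_q)` call should be `model.fit(states, target_q, batch_size=len(states), epochs=1, shuffle=False, verbose=0)`, which performs a single update on the whole array.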