
Passing pre-computed batches to TensorFlow fit() method


I would like to train a Keras model (say, a simple FFNN) using the model.fit() method rather than writing the training loop 'by hand' (i.e. with tf.GradientTape, as explained for example here). However, the loss function I need is quite elaborate and cannot be computed on randomly generated batches of data. As a result, I need to train the model on batches computed 'by hand': the samples that go into each batch must have certain properties and cannot be randomly assigned.

Can I somehow pass pre-computed batches to the fit() method?
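To make the requirement concrete, here is one hypothetical way to build such batches by hand with NumPy: every batch contains only samples that share the same group label. The helper name make_group_batches and the grouping rule are illustrative assumptions, not part of the original question; the result is two ordered lists of arrays, one entry per batch.

```python
import numpy as np


def make_group_batches(x, y, groups, batch_size):
    """Hypothetical helper: build batches whose samples all share one group label.

    Returns two lists (input batches and target batches) in a fixed order.
    """
    x_batches, y_batches = [], []
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)  # indices of the samples in group g
        for start in range(0, len(idx), batch_size):
            chunk = idx[start:start + batch_size]
            x_batches.append(x[chunk])
            y_batches.append(y[chunk])
    return x_batches, y_batches


# Toy data: 10 samples with 2 features and a group label per sample
x = np.arange(20, dtype="float32").reshape(10, 2)
y = np.arange(10, dtype="float32")
groups = np.array([0, 1, 0, 1, 0, 1, 0, 1, 2, 2])

xb, yb = make_group_batches(x, y, groups, batch_size=2)
print([len(b) for b in yb])  # → [2, 2, 2, 2, 2]
```

Any rule that decides which samples belong together can replace the grouping step; what matters is that the batches are fixed before training starts.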


Solution

  • One solution is to subclass tf.keras.utils.Sequence. Its __getitem__ method returns the batch you built for a given index.

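    import tensorflow as tf

    class MySequence(tf.keras.utils.Sequence):
        """Serves pre-computed batches in a fixed order."""

        def __init__(self, x_batch, y_batch, **kwargs) -> None:
            super().__init__(**kwargs)  # forwards options such as workers= in Keras 3
            self.x_batch = x_batch  # ordered list of input batches
            self.y_batch = y_batch  # ordered list of target batches, same length

        def __len__(self):
            # Number of batches per epoch
            return len(self.y_batch)

        def __getitem__(self, idx):
            # Return the idx-th pre-computed batch as an (inputs, targets) tuple
            return self.x_batch[idx], self.y_batch[idx]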
    

    You can pass an instance of this Sequence subclass directly to Model.fit().
    Also set shuffle=False in the fit() arguments so that Keras keeps the batches in the order you defined instead of shuffling their order each epoch.
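A minimal end-to-end sketch, assuming TensorFlow 2.x with either Keras 2 or Keras 3. It repeats the Sequence subclass so the block runs on its own, feeds five hand-made batches of random data to a small FFNN, and uses a hypothetical batch-level loss (mean squared error plus the squared mean residual of the batch). That extra term is the kind of loss whose value depends on which samples share a batch. The data, layer sizes and the name batch_loss are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf


class MySequence(tf.keras.utils.Sequence):
    """Serves pre-computed batches in a fixed order."""

    def __init__(self, x_batch, y_batch, **kwargs):
        super().__init__(**kwargs)
        self.x_batch = x_batch  # ordered list of input batches
        self.y_batch = y_batch  # ordered list of target batches

    def __len__(self):
        return len(self.y_batch)

    def __getitem__(self, idx):
        return self.x_batch[idx], self.y_batch[idx]


def batch_loss(y_true, y_pred):
    # Hypothetical batch-level loss: MSE plus a penalty on the mean residual
    # of the whole batch, so its value depends on how samples were grouped.
    resid = y_true - y_pred
    return tf.reduce_mean(tf.square(resid)) + tf.square(tf.reduce_mean(resid))


# Hypothetical data: 5 hand-made batches of 4 samples with 3 features each
rng = np.random.default_rng(0)
x_batches = [rng.normal(size=(4, 3)).astype("float32") for _ in range(5)]
y_batches = [rng.normal(size=(4, 1)).astype("float32") for _ in range(5)]

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss=batch_loss)

# shuffle=False keeps the batches in the order they were built
history = model.fit(MySequence(x_batches, y_batches), epochs=2, shuffle=False, verbose=0)
print(history.history["loss"])
```

Because each training step receives exactly one of the pre-computed batches, the batch-level term is evaluated on the groups you chose rather than on random subsets of the data.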