python, tensorflow, keras

Why do I run out of memory when training with a large dataset, but have no problems with a small dataset?


I'm trying to build a keypoint detection system using Keras. I have a U-Net-like model: a series of convolution, batch normalization, and max pooling layers, followed by a symmetric series of upsampling, convolution, and batch normalization layers (with skip connections). When given 100 instances, I'm able to call model.fit() without a problem. However, if I keep the model the same but use 500 instances, Keras crashes with an OOM exception. Why does this happen, and is there anything I can do to fix it?
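
For reference, build_model produces a network along these lines. This is only a simplified sketch, not the actual function (the real one also uses the extra arguments shown in the call further down, e.g. filter_step and stacks):

import tensorflow as tf
from tensorflow.keras import layers

def build_model_sketch(filters=50, stages=5, input_shape=(1024, 1024, 3)):
    # Simplified stand-in: conv/BN/pool encoder, upsample/conv/BN decoder with skips
    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []
    for _ in range(stages):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)
        skips.append(x)                      # saved for the skip connection
        x = layers.MaxPooling2D(2)(x)
    for skip in reversed(skips):
        x = layers.UpSampling2D(2)(x)
        x = layers.Concatenate()([x, skip])  # skip connection
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.BatchNormalization()(x)
    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)  # per-pixel keypoint heatmap
    return tf.keras.Model(inputs, outputs)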

Here's (what I think is) the relevant part of the code where I call model.fit:

model = build_model(
    filters=50,
    filter_step=1,
    stages=5,
    stage_steps=1,
    initial_convolutions=0,
    stacks=1,
)

model.summary()  # already prints the architecture; wrapping it in print() just adds "None"

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.batch(1)

model.fit(
    dataset,
    epochs=2**7,
    callbacks=[
        EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
        LearningRateScheduler(step_decay)
    ],
)

X and y are NumPy arrays with shapes (100, 1024, 1024, 3) and (100, 1024, 1024) respectively; 100 here is the dataset size. If I increase it to 500 (or more), I get the out-of-memory exception. It appears to me that Keras is trying to load the entire dataset into memory, despite my using from_tensor_slices and batch(1), so I'm clearly misunderstanding something.


Solution

  • When you use tf.data.Dataset.from_tensor_slices((X, y)), TensorFlow first copies the entire X and y arrays into TensorFlow tensors and only then slices them into (X[i], y[i]) pairs. That means the whole dataset, not just the current batch, has to fit in memory at once; batch(1) only controls how that already-materialized data is grouped. With arrays as large as yours, going from 100 to 500 instances is enough to blow past the available memory.
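
    To put rough numbers on it (assuming the 1024×1024 shapes used in the example below and NumPy's default float64 dtype; adjust for your actual arrays):

    # Back-of-the-envelope footprint of the raw arrays that
    # from_tensor_slices would copy into TensorFlow tensors
    n, h, w, c = 500, 1024, 1024, 3
    bytes_x = n * h * w * c * 8   # float64 = 8 bytes per element
    bytes_y = n * h * w * 8
    print(f"X ~ {bytes_x / 1e9:.1f} GB, y ~ {bytes_y / 1e9:.1f} GB")
    # X ~ 12.6 GB, y ~ 4.2 GB -- before TensorFlow makes its own copy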

    To address this, load the data in batches at training time instead of handing TensorFlow the whole arrays up front: define a Python generator that yields batches of (X_batch, y_batch), then build the dataset from it with:

    tf.data.Dataset.from_generator

    Full documentation: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator

    And an example could look like this:

    import numpy as np
    import tensorflow as tf
    
    # Assume X and y are your data. Cast to float32 so they match the
    # output_signature below (and take half the memory of float64).
    X = np.random.rand(500, 1024, 1024, 3).astype(np.float32)
    y = np.random.rand(500, 1024, 1024).astype(np.float32)
    
    # Generator that yields one batch (X_batch, y_batch) at a time
    def data_generator(X, y, batch_size):
        num_samples = X.shape[0]
        for i in range(0, num_samples, batch_size):
            yield X[i:i+batch_size], y[i:i+batch_size]
    
    # Parameters
    batch_size = 16
    
    # Create a tf.data.Dataset from the generator. The lambda calls
    # data_generator(...) so that a fresh generator is produced each time
    # the dataset is iterated (i.e. every epoch); passing in a single
    # pre-built generator object would leave it exhausted after epoch 1.
    dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(X, y, batch_size),
        output_signature=(
            tf.TensorSpec(shape=(None, 1024, 1024, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 1024, 1024), dtype=tf.float32)
        )
    )
    
    # Model and training code would go here...
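
    For completeness, here is roughly how that last step could look, reusing the model, callbacks, and step_decay from your question; prefetch() is optional but lets the next batch be prepared while the current one trains:

    from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler
    
    # Overlap batch preparation with training
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    model.fit(
        dataset,
        epochs=2**7,
        callbacks=[
            EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
            LearningRateScheduler(step_decay),
        ],
    )

    No steps_per_epoch is needed: each epoch simply runs until the freshly created generator is exhausted.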