I'm trying to build a keypoint detection system using Keras. I've got a UNet-like model with a series of convolution, batch normalization, and max pooling layers, followed by a symmetric series of upsampling, convolution, and batch normalization layers (and skip connections). When given 100 instances, I'm able to call model.fit() without a problem. However, if I leave the model the same but use 500 instances, Keras crashes with an OOM exception. Why does this happen, and is there anything I can do to fix it?
Here's (what I think is) the relevant part of the code where I call model.fit:
model = build_model(
    filters=50,
    filter_step=1,
    stages=5,
    stage_steps=1,
    initial_convolutions=0,
    stacks=1,
)
print(model.summary())

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.batch(1)

model.fit(
    dataset,
    epochs=2**7,
    callbacks=[
        EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
        LearningRateScheduler(step_decay),
    ],
)
X and y are NumPy arrays with the following shapes:

X: (100, 1024, 1024, 3)
y: (100, 1024, 1024)

100 here is the data set size. If I increase this to 500 (or more), I get the out-of-memory exception. It appears to me that Keras is perhaps trying to load the entire data set into memory, despite using from_tensor_slices and batch(1), so I'm clearly misunderstanding something.
When you use tf.data.Dataset.from_tensor_slices((X, y)), TensorFlow materializes the entire X and y arrays as in-memory tensors and then slices them into (X[i], y[i]) pairs. The batch(1) call only changes how those elements are grouped during iteration; the whole dataset still has to fit in memory up front, which is why 100 instances work but 500 (or more) blow up.
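You can estimate the footprint directly. This is a quick sketch that assumes float32 data; your actual dtype may differ, and NumPy's default float64 would double these numbers:

import numpy as np

# Rough size of the tensors that from_tensor_slices materializes up front.
bytes_per_value = np.dtype(np.float32).itemsize
x_bytes = 500 * 1024 * 1024 * 3 * bytes_per_value
y_bytes = 500 * 1024 * 1024 * bytes_per_value
print(f"X: {x_bytes / 1e9:.1f} GB")  # ~6.3 GB
print(f"y: {y_bytes / 1e9:.1f} GB")  # ~2.1 GB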
To address this memory issue, you can load the data in batches at runtime using a generator instead. Define a generator that yields batches of data (X_batch, y_batch), then create the dataset with tf.data.Dataset.from_generator (the TensorFlow API reference for from_generator has the full details).
An example could look like this:
import numpy as np
import tensorflow as tf

# Assume X and y are your data (float32 here to match the output_signature below)
X = np.random.rand(500, 1024, 1024, 3).astype(np.float32)
y = np.random.rand(500, 1024, 1024).astype(np.float32)

# Define a generator to yield batches of data
def data_generator(X, y, batch_size):
    num_samples = X.shape[0]
    for i in range(0, num_samples, batch_size):
        yield X[i:i+batch_size], y[i:i+batch_size]

# Parameters
batch_size = 16

# Create a tf.data.Dataset using the generator. Passing a callable that builds
# a fresh generator (rather than a single generator object) lets the dataset be
# re-iterated on every epoch instead of being exhausted after the first one.
dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X, y, batch_size),
    output_signature=(
        tf.TensorSpec(shape=(None, 1024, 1024, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1024, 1024), dtype=tf.float32),
    ),
)

# Model and training code would go here...
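From there, training looks the same as in your question. This is a minimal sketch that reuses the question's own build_model, step_decay, and callbacks, which are assumed to be defined elsewhere in your code:

from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler

# Overlap data preparation with training on the accelerator.
dataset = dataset.prefetch(tf.data.AUTOTUNE)

model = build_model(  # from the question; assumed available here
    filters=50,
    filter_step=1,
    stages=5,
    stage_steps=1,
    initial_convolutions=0,
    stacks=1,
)

model.fit(
    dataset,
    epochs=2**7,
    callbacks=[
        EarlyStopping(monitor="loss", patience=5, min_delta=1e-7, start_from_epoch=10),
        LearningRateScheduler(step_decay),
    ],
)

Note that the generator already yields batches of batch_size, so there is no extra dataset.batch(...) call here.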