I'm running into an issue when implementing a training loop that uses a tf.data.Dataset
as input to a Keras model. My dataset has an element spec of the following format:
({'data': TensorSpec(shape=(15000, 1), dtype=tf.float32), 'index': TensorSpec(shape=(2,), dtype=tf.int64)}, TensorSpec(shape=(1,), dtype=tf.int32))
So, basically, each sample is structured as a tuple (x, y), in which x is a dict containing two tensors: data, with shape (15000, 1), and index, with shape (2,) (the index is not used during training); y is a single label.
The tf.data.Dataset is created using dataset = tf.data.Dataset.from_tensor_slices((X, y)), where X is a dict with two keys: data, an np array of shape (200k, 15000, 1), and index, an np array of shape (200k, 2); y is a single array of shape (200k, 1).
My dataset has about 200k training samples (after running undersampling) and 200k validation samples.
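For reference, a minimal sketch of this structure (using tiny placeholder arrays instead of the real 200k samples) that verifies the element spec matches the one shown above:
import numpy as np
import tensorflow as tf

# Tiny placeholder arrays with the same per-sample shapes and dtypes
X = {'data': np.zeros((10, 15000, 1), dtype=np.float32),
     'index': np.zeros((10, 2), dtype=np.int64)}
y = np.zeros((10, 1), dtype=np.int32)

dataset = tf.data.Dataset.from_tensor_slices((X, y))
print(dataset.element_spec)
# ({'data': TensorSpec(shape=(15000, 1), dtype=tf.float32, name=None),
#   'index': TensorSpec(shape=(2,), dtype=tf.int64, name=None)},
#  TensorSpec(shape=(1,), dtype=tf.int32, name=None))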
Right after calling tf.data.Dataset.from_tensor_slices I noticed a spike in GPU memory usage: about 16GB is occupied after creating the training tf.data.Dataset, and 16GB more after creating the validation tf.data.Dataset.
After creating the tf.data.Datasets, I run a few operations (e.g. shuffle, batching, and prefetching) and call model.fit. My model has about 500k trainable parameters.
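In code, those steps look roughly like this (a sketch: the shuffle buffer size and epoch count are assumptions, and train_dataset, val_dataset and model are defined as in the reproducible example below):
# Sketch of the input pipeline and training call; the buffer size (10_000)
# and epochs (10) are assumed values, the batch size of 256 is the real one
train_dataset = (train_dataset
                 .shuffle(10_000)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE))
val_dataset = val_dataset.batch(256).prefetch(tf.data.AUTOTUNE)

model.fit(train_dataset, validation_data=val_dataset, epochs=10)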
The issue I'm running into comes after fitting the model. I need to run inference on some additional data, so I create a new tf.data.Dataset with this data, again using tf.data.Dataset.from_tensor_slices. However, I noticed the training and validation tf.data.Datasets still reside in GPU memory, which causes my script to break with an out-of-memory error for the new tf.data.Dataset I want to run inference on.
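The inference step that fails looks roughly like this (a sketch: the inference array names and the model.predict call are placeholders for what my script does):
# Hypothetical inference data of similar size to the training data
inference_X = {'data': inference_data, 'index': inference_index}
inference_dataset = tf.data.Dataset.from_tensor_slices(inference_X)
inference_dataset = inference_dataset.batch(256)
predictions = model.predict(inference_dataset)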
I tried calling del on the two tf.data.Datasets and subsequently calling gc.collect(), but I believe that only clears RAM, not GPU memory. I also tried disabling some of the operations I apply, such as prefetch, and playing with the batch size, but none of that worked. Calling keras.backend.clear_session() did not clear GPU memory either. I also tried importing cuda from numba, but due to my install I cannot use it to clear memory. Is there any way for me to clear the tf.data.Dataset objects from GPU memory?
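For reference, the cleanup attempts described above were roughly the following (none of them released the GPU memory):
import gc
from tensorflow import keras

del train_dataset, val_dataset   # removes the Python references only
gc.collect()                     # appears to clear RAM, not GPU memory
keras.backend.clear_session()    # did not release the GPU memory either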
Minimum reproducible example below
Setup
import numpy as np
import tensorflow as tf
from itertools import product
# Setting tensorflow memory growth for GPU
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
Create dummy data of a similar size to my actual data (the dtypes are the same as the actual data):
train_index = np.array(list(product(np.arange(1000), np.arange(200)))).astype(np.int32)
train_data = np.random.rand(200000, 15000).astype(np.float32)
train_y = np.random.randint(0, 2, size=(200000, 1)).astype(np.int32)
val_index = np.array(list(product(np.arange(1000), np.arange(200)))).astype(np.int32)
val_data = np.random.rand(200000, 15000).astype(np.float32)
val_y = np.random.randint(0, 2, size=(200000, 1)).astype(np.int32)
This is the nvidia-smi output at this point:
Creating the training tf.data.Dataset, with a batch size of 256:
train_X = {'data': train_data, 'index':train_index}
train_dataset = tf.data.Dataset.from_tensor_slices((train_X, train_y))
train_dataset = train_dataset.batch(256)
This is the nvidia-smi output after the training tf.data.Dataset creation:
Creating the validation tf.data.Dataset, with a batch size of 256:
val_X = {'data': val_data, 'index':val_index}
val_dataset = tf.data.Dataset.from_tensor_slices((val_X, val_y))
val_dataset = val_dataset.batch(256)
This is the nvidia-smi output after the second tf.data.Dataset creation:
So GPU usage grows when creating each tf.data.Dataset. Since I need to create a new tf.data.Dataset of similar size after running model.fit, I end up running out of memory. Is there any way to clear this data from GPU memory?
The problem is due to a cache that does not get cleared when needed; this is an open issue. The only workaround I found is to build a large dataset that replaces the old one in the cache:
dataset = tf.data.Dataset.range(num_epochs // 8)  # drop the cache every 8 epochs
dataset = dataset.flat_map(lambda _: create_dataset().repeat(8))  # rebuild the dataset for each block of 8 epochs
model.fit(dataset, ...)
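A fuller sketch of this pattern, shown with small placeholder arrays; the create_dataset helper, num_epochs, the steps_per_epoch computation, and model (the trained Keras model) are assumptions added for illustration, not part of the original workaround:
import math
import numpy as np
import tensorflow as tf

# Small placeholder arrays standing in for the real training data
data = np.random.rand(1000, 15000).astype(np.float32)
index = np.zeros((1000, 2), dtype=np.int64)
labels = np.random.randint(0, 2, size=(1000, 1)).astype(np.int32)

def create_dataset():
    # Rebuild the training pipeline from scratch; each call returns a
    # fresh dataset so the previously cached one can be replaced
    ds = tf.data.Dataset.from_tensor_slices(({'data': data, 'index': index}, labels))
    return ds.batch(256)

num_epochs = 16                                  # assumed value
steps_per_epoch = math.ceil(len(labels) / 256)   # batches per full pass

# One outer element per block of 8 epochs; flat_map rebuilds the inner
# dataset for every block
dataset = tf.data.Dataset.range(num_epochs // 8)
dataset = dataset.flat_map(lambda _: create_dataset().repeat(8))

model.fit(dataset, epochs=num_epochs, steps_per_epoch=steps_per_epoch)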