python tensorflow machine-learning keras

Clearing tf.data.Dataset from GPU memory


I'm running into an issue when implementing a training loop that uses a tf.data.Dataset as input to a Keras model. My dataset has an element spec of the following format:

({'data': TensorSpec(shape=(15000, 1), dtype=tf.float32), 'index': TensorSpec(shape=(2,), dtype=tf.int64)}, TensorSpec(shape=(1,), dtype=tf.int32))

So, basically, each sample is structured as a tuple (x, y), in which x is a dict containing two tensors, 'data' with shape (15000, 1) and 'index' with shape (2,) (the index is not used during training), and y is a single label.

The tf.data.Dataset is created using dataset = tf.data.Dataset.from_tensor_slices((X, y)), where X is a dict with two keys, 'data' (an array of shape (200k, 15000)) and 'index' (an array of shape (200k, 2)), and y is a single array of shape (200k, 1).

My dataset has about 200k training samples (after running undersampling) and 200k validation samples.

Right after calling tf.data.Dataset.from_tensor_slices I noticed a spike in GPU memory usage: about 16 GB is occupied after creating the training dataset, and 16 GB more after creating the validation dataset.

After creating the datasets, I apply a few operations (e.g. shuffle, batch, and prefetch) and call model.fit. My model has about 500k trainable parameters.
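Roughly, the pipeline looks like this (the shuffle buffer size and epoch count are placeholders, and model stands for my ~500k-parameter Keras model):

# Rough sketch of the pipeline; buffer size and epochs are placeholders
train_dataset = train_dataset.shuffle(10_000).batch(256).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(256).prefetch(tf.data.AUTOTUNE)

model.fit(train_dataset, validation_data=val_dataset, epochs=20)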

The issue I'm running into comes after fitting the model. I need to run inference on some additional data, so I create a new tf.data.Dataset with this data, again using tf.data.Dataset.from_tensor_slices. However, the training and validation datasets still reside in GPU memory, so my script breaks with an out-of-memory error when creating the new inference dataset.
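The failing step is essentially this (test_data, test_index and test_y stand in for the real inference arrays):

# test_data, test_index and test_y stand in for the real inference arrays
test_X = {'data': test_data, 'index': test_index}
test_dataset = tf.data.Dataset.from_tensor_slices((test_X, test_y))  # OOM happens here
test_dataset = test_dataset.batch(256)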

I tried calling del on the two datasets and then gc.collect(), but I believe that only frees host RAM, not GPU memory. I also tried disabling some of the operations I apply (such as prefetch) and playing with the batch size, but none of that worked. Calling keras.backend.clear_session() did not free GPU memory either, and while I tried importing cuda from numba, my install does not let me use it to clear memory. Is there any way for me to clear the tf.data.Dataset from GPU memory?
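For completeness, this is roughly what the cleanup attempt looks like (none of it frees the GPU memory):

import gc
from tensorflow import keras

# Drop the Python references and force a garbage collection
del train_dataset, val_dataset
gc.collect()

# Clear the Keras/TensorFlow graph state; the dataset buffers stay on the GPU
keras.backend.clear_session()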

Minimal reproducible example below

Setup

import numpy as np
import tensorflow as tf

from itertools import product

# Setting tensorflow memory growth for GPU
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

Create dummy data with the same size and dtypes as my actual data:

train_index = np.array(list(product(np.arange(1000), np.arange(200)))).astype(np.int32)
train_data = np.random.rand(200000, 15000).astype(np.float32)
train_y = np.random.randint(0, 2, size=(200000, 1)).astype(np.int32)

val_index = np.array(list(product(np.arange(1000), np.arange(200)))).astype(np.int32)
val_data = np.random.rand(200000, 15000).astype(np.float32)
val_y = np.random.randint(0, 2, size=(200000, 1)).astype(np.int32)

This is the nvidia-smi output at this point: [screenshot: nvidia-smi before creating the first tf.data.Dataset]

Creating the training tf.data.Dataset, with a batch size of 256:

train_X = {'data': train_data, 'index':train_index}
train_dataset = tf.data.Dataset.from_tensor_slices((train_X, train_y))
train_dataset = train_dataset.batch(256)

This is the nvidia-smi output after the first tf.data.Dataset creation: [screenshot: nvidia-smi after creating the first tf.data.Dataset]

Creating the validation tf.data.Dataset, with a batch size of 256:

val_X = {'data': val_data, 'index':val_index}
val_dataset = tf.data.Dataset.from_tensor_slices((val_X, val_y))
val_dataset = val_dataset.batch(256)

This is the nvidia-smi output after the second tf.data.Dataset creation: [screenshot: nvidia-smi after creating the second tf.data.Dataset]

So GPU memory usage grows with each tf.data.Dataset I create. Since I need to create another tf.data.Dataset of a similar size after running model.fit, I end up running out of memory. Is there any way to clear this data from GPU memory?


Solution

  • The problem is caused by the dataset cache not being cleared when it should be; this is an open issue.

    The only workaround I found is to build a new dataset that replaces the old one in the cache:

    dataset = tf.data.Dataset.range(num_epochs // 8)  # drop the cache every 8 epochs
    dataset = dataset.flat_map(lambda _: create_dataset().repeat(8))
    model.fit(dataset, ...)
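
    A fuller sketch of the same idea, assuming a placeholder create_dataset() factory that rebuilds the training pipeline from the source arrays (the factory name and epoch counts are placeholders, not a fixed API):

    # Sketch only: create_dataset() is a placeholder factory that rebuilds the
    # training pipeline from scratch each time it is called
    def create_dataset():
        train_X = {'data': train_data, 'index': train_index}
        return tf.data.Dataset.from_tensor_slices((train_X, train_y)).batch(256)

    num_epochs = 32

    # One outer element per 8-epoch chunk; flat_map rebuilds the inner dataset,
    # replacing whatever the previous chunk left behind
    dataset = tf.data.Dataset.range(num_epochs // 8)
    dataset = dataset.flat_map(lambda _: create_dataset().repeat(8))

    # The flattened dataset already streams all num_epochs passes over the data,
    # so a single Keras epoch consumes the whole thing
    model.fit(dataset, epochs=1)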