I wrote a TF data pipeline that looks something like this (TF 2.6):
import tensorflow as tf

def parse(img):
    # Decode a single PNG and force the expected shape and dtype.
    image = tf.image.decode_png(img, channels=3)
    image = tf.reshape(image, IMG_SHAPE)
    image = tf.cast(image, TARGET_DTYPE)
    return image

def decode_batch(serialized_example, is_test=False):
    feature_dict = {
        'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
    }
    if not is_test:
        feature_dict["some_text"] = tf.io.FixedLenFeature(
            shape=[MAX_LEN], dtype=tf.int64, default_value=[0] * MAX_LEN)
    else:
        feature_dict["image_id"] = tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value='')
    # Parse a whole batch of serialized examples at once, then decode the images.
    features = tf.io.parse_example(
        tf.reshape(serialized_example, [BATCH_SIZE_OVERALL]), features=feature_dict)
    images = tf.map_fn(parse, features['image'], parallel_iterations=4,
                       fn_output_signature=TARGET_DTYPE)
    if is_test:
        image_ids = features["image_id"]
        return images, image_ids
    else:
        targets = tf.cast(features["some_text"], tf.uint8)
        return images, targets

def get_dataset(filenames, is_test):
    opts = tf.data.Options()
    opts.experimental_deterministic = False
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.with_options(opts)
    dataset = dataset.interleave(
        lambda x: tf.data.TFRecordDataset(x),
        cycle_length=4,
        num_parallel_calls=4,
    )
    # Batch the raw serialized examples first, then decode them batch-wise in map().
    dataset = dataset.batch(BATCH_SIZE_OVERALL, num_parallel_calls=4, drop_remainder=True)
    if not is_test:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(BATCH_SIZE_OVERALL * 6)
    dataset = dataset.map(lambda y: decode_batch(y, is_test), num_parallel_calls=4)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

train_ds = get_dataset(TRAIN_TFREC_PATHS, False)
As you can see from the code, I applied most of the tricks from the TF guide on building an efficient tf.data pipeline. The problem is the following: when training starts, the code uses only 1 of the 4 cores instead of all of them (sometimes more cores are used, but that seems to be caused by the train_dist_ds.get_next() call in the code below). The GPU is also barely utilized. The profiler says the problem is in preprocessing, and tf_data_bottleneck_analysis indicates that the bottleneck is in ParallelBatch (although it once pointed to ParallelMap instead, which seems plausible, but that by itself does not say much: the cores are underutilized either way). The training function with the profiler looks like this:
import time

def fit_profile(train_ds, val_ds, stop_after_steps):
    tf.profiler.experimental.start('logdir')
    stat_logger.current_step = 0
    train_dist_ds = iter(train_ds)
    while True:
        stat_logger.batch_start_time = time.time()
        stat_logger.current_step += 1
        print(f'current step: {stat_logger.current_step}')
        # Trace both the get_next() call and the train step itself.
        with tf.profiler.experimental.Trace('train', step_num=stat_logger.current_step, _r=1):
            image_batch, some_text_batch = train_dist_ds.get_next()
            train_step(image_batch, some_text_batch)
        if stat_logger.current_step == stop_after_steps:
            break
    tf.profiler.experimental.stop()
As you can see, I don't touch the dataset itself and I don't wrap it in any distribution strategy; all the computation happens in train_step (which is, of course, wrapped in @tf.function).
Questions: is there a way to debug the computations inside the graph for tf.data operations, ideally at the level of each tf.data API call inside the preprocessing, so that I can understand what exactly to optimize? And what could be the reason that only one core is used?
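The closest I have to isolating tf.data from the rest is to time the iterator on its own (a rough sketch; it only shows whether the pipeline as a whole is slow, not which op is at fault):

import time

def benchmark_dataset(ds, num_steps=100):
    # Pull batches from the pipeline without running any model,
    # so the measured time is attributable to tf.data alone.
    it = iter(ds)
    start = time.perf_counter()
    for _ in range(num_steps):
        next(it)
    elapsed = time.perf_counter() - start
    print(f'{num_steps} batches in {elapsed:.2f} s '
          f'({elapsed / num_steps * 1000:.1f} ms/batch)')

benchmark_dataset(train_ds)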
What I've tried so far:
- tf.data.AUTOTUNE - no effect;
- parallel_iterations in the map_fn call - no effect;
- num_parallel_calls - no effect, to the point that it seems like it really doesn't matter.

I finally found the reason for this behaviour: it was caused by using XLA with the GPU.
I suddenly found this, decided to turn XLA off, and, after almost a week of investigation, the GPU was finally fully utilized and training times became way more sane (before that they were equal to the CPU training times!). As the article says: 1) GPU support in XLA is experimental; 2) tensors need to have inferrable shapes; 3) all operations in the graph must be supported by XLA. Signs of such problems are poor CPU and GPU utilization, as well as bouncing training step times, i.e. one step takes 150 seconds, the next 8-10 steps take one second each, and then the pattern repeats. The article talks about TF 1.x, but it seems that not much has changed on this topic since then (again, I'm using TF 2.6).
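For reference, if XLA was enabled through auto-clustering, it can be switched off globally like this (a sketch; how exactly you turn it off depends on how it was enabled in your setup, e.g. it may instead be the jit_compile argument of tf.function):

# Disable XLA auto-clustering for the whole program.
tf.config.optimizer.set_jit(False)

# If XLA was enabled per function instead, drop the flag from the decorator:
# @tf.function(jit_compile=True)  ->  @tf.function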
Main takeaway: I will update this answer if I manage to meet these XLA requirements in my computations and turn XLA back on with a performance boost rather than a degradation.
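If I do get there, the change should look roughly like this (a sketch, not tested yet: compile only the compute-heavy step and keep every shape it sees static):

# Hypothetical follow-up: enable XLA only for train_step, once every tensor
# it sees has a fully inferrable shape (drop_remainder=True in get_dataset
# already keeps the batch dimension fixed) and all ops in it are XLA-supported.
@tf.function(jit_compile=True)
def train_step(image_batch, some_text_batch):
    # same body as before
    ...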