Tags: python, tensorflow, keras, tensorflow-xla, xla

Why does tensorflow.function (without jit_compile) speed up forward passes of a Keras model?


XLA can be enabled using model = tf.function(model, jit_compile=True). Some model types are faster that way, some are slower. So far, so good.

But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?

The jit_compile docs state:

If None (default), compiles the function with XLA when running on TPU and goes through the regular function execution path when running on other devices.

I'm running my tests on two non-TPU (and even non-GPU) machines, both with the latest TensorFlow (2.13.0) installed.

import timeit

import numpy as np
import tensorflow as tf

model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)


def run(model):
    model(np.random.random(size=(1, 384, 384, 3)))


# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)

runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs

print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078

Solution

  • But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?

    The speedup comes mainly from the graph mode enabled by tf.function, which is much faster than the eager execution used for model_plain.

    On top of that, there are secondary effects of the XLA compilation controlled by the jit_compile flag, but these depend heavily on the compute architecture. For instance, the numbers look quite different when the model is compiled for a GPU.
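To see the graph-mode effect in isolation, here is a minimal sketch (not from the original post; dense_step and the tensor shapes are illustrative) that times the same tiny computation eagerly and through tf.function without XLA:

```python
import timeit

import tensorflow as tf


def dense_step(x, w):
    # one matmul + ReLU, executed op-by-op in eager mode
    return tf.nn.relu(tf.matmul(x, w))


# same computation, but traced once and then run as a single graph
dense_graph = tf.function(dense_step)

x = tf.random.normal((64, 256))
w = tf.random.normal((256, 256))

# warmup: the first call triggers tracing, which must not be timed
dense_graph(x, w)

t_eager = timeit.timeit(lambda: dense_step(x, w), number=100)
t_graph = timeit.timeit(lambda: dense_graph(x, w), number=100)
print(f"eager: {t_eager:.4f}s  graph: {t_graph:.4f}s")
```

For a computation this small the gap may be modest, since graph mode mainly amortizes per-op Python dispatch overhead; for a deep model like EfficientNetV2S, with hundreds of ops per forward pass, that overhead dominates, which is why model_plain is so much slower.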

    Last but not least, the benchmarking methodology should be corrected to account for variance, which is substantial for 10 runs of this use case. Otherwise, the findings will be misleading or even contradictory: due to high variance, jit_compile=None can appear faster on average. For future reference, note that this profiling pattern from the TensorFlow docs is inaccurate:

    # average runtime on 10 repetitions without variance is inaccurate
    print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
    

    The following corrected and extended snippet, executed on Kaggle notebooks with GPU, demonstrates that improvements come mostly from the graph mode and that XLA compilation gives some further speedup.

    import timeit
    
    import numpy as np
    import tensorflow as tf
    
    model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
    model_tffunc = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
    model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
    model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
    model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
    
    x = np.random.random(size=(1, 384, 384, 3))
    
    def run(model):
        model(x)
    
    # warmup
    run(model_plain)
    run(model_tffunc)
    run(model_jit_compile_true)
    run(model_jit_compile_false)
    run(model_jit_compile_none)
    
    # benchmarking
    duration_plain = %timeit -o run(model_plain)
    duration_tffunc = %timeit -o run(model_tffunc)
    duration_jit_compile_true = %timeit -o run(model_jit_compile_true)
    duration_jit_compile_false = %timeit -o run(model_jit_compile_false)
    duration_jit_compile_none = %timeit -o run(model_jit_compile_none)
    
    print(f"{str(duration_plain)=}")
    print(f"{str(duration_tffunc)=}")
    print(f"{str(duration_jit_compile_true)=}")
    print(f"{str(duration_jit_compile_false)=}")
    print(f"{str(duration_jit_compile_none)=}")
    

    Statistically, we have: duration_plain > duration_jit_compile_false = duration_jit_compile_none = duration_tffunc > duration_jit_compile_true, as seen from the output:

    369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    str(duration_plain)='369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)'
    str(duration_tffunc)='16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)'
    str(duration_jit_compile_true)='11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'
    str(duration_jit_compile_false)='15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)'
    str(duration_jit_compile_none)='15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'
    

    For a complete example, see this public notebook.

    NOTE: this way of measuring variation is useful but not fully accurate.
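Outside a notebook, where the %timeit magic is unavailable, the same mean-and-spread measurement can be sketched with the standard library alone. This is a hedged example: benchmark() and the toy workload are illustrative stand-ins, not part of the original code.

```python
import statistics
import timeit


def benchmark(fn, repeat=7, number=10):
    """Report per-call mean and std. dev. across repeats, like %timeit."""
    # timeit.repeat returns `repeat` total durations, each covering `number` calls
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)


# toy workload standing in for run(model)
workload = lambda: sorted(range(10_000), key=lambda i: -i)

mean, std = benchmark(workload)
print(f"{mean * 1e3:.3f} ms ± {std * 1e3:.3f} ms per loop "
      f"(mean ± std. dev. of 7 runs, 10 loops each)")
```

Reporting the standard deviation alongside the mean is exactly what makes the corrected benchmark above trustworthy: without it, a single noisy run can flip the apparent ranking of jit_compile variants.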