XLA can be enabled using model = tf.function(model, jit_compile=True). Some model types are faster that way, some are slower. So far, so good. But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?
The jit_compile docs state: "If None (default), compiles the function with XLA when running on TPU and goes through the regular function execution path when running on other devices."
I'm running my tests on two non-TPU (and even non-GPU) machines, both with the latest TensorFlow (2.13.0) installed.
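A quick way to confirm which devices TensorFlow sees (a small sanity-check sketch, not part of the benchmark):

```python
import tensorflow as tf

# Per the quoted docs, jit_compile=None only engages XLA on TPU.
# On a TPU-less (and GPU-less) machine both lists are empty, so the
# regular graph execution path is taken.
print(tf.config.list_physical_devices("TPU"))
print(tf.config.list_physical_devices("GPU"))
```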
import timeit
import numpy as np
import tensorflow as tf
model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
def run(model):
    model(np.random.random(size=(1, 384, 384, 3)))
# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)
runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs
print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078
But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?
The speedup comes mainly from the graph mode enabled by tf.function, which is much faster than the eager execution used in model_plain.
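To illustrate the graph-mode effect in isolation (a minimal sketch, separate from the benchmark above): tf.function traces the Python function into a graph on the first call and reuses that graph on later calls.

```python
import tensorflow as tf

@tf.function
def scaled_sum(x):
    # Traced into a TensorFlow graph on the first call for this input
    # signature; later calls bypass the Python bytecode and run the
    # cached graph directly.
    return tf.reduce_sum(x) * 2.0

x = tf.ones((3,))
print(scaled_sum(x))  # first call: traces, then executes the graph
print(scaled_sum(x))  # second call: reuses the cached graph
```

This tracing-once-then-reusing behavior is what separates model_plain (eager, Python-level dispatch per op) from all the tf.function-wrapped variants.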
On top of that, there are secondary effects from XLA compilation via the jit_compile flag, but these depend heavily on the computing architecture. For instance, the numbers look very different when the model runs on a GPU.
Last but not least, the benchmarking methodology should be corrected to take variation into account, which is indeed huge for 10 runs and the use case in question. Otherwise the findings will be misleading or even contradictory: due to the high variation, jit_compile=None can appear faster on average when it is not.
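With the standard library alone, timeit.repeat gives one total per repetition, from which a mean and a spread can be reported (a sketch with a stand-in workload; substitute run(model_...) in the real benchmark):

```python
import statistics
import timeit

def work():
    # stand-in workload; replace with run(model_...) for the real case
    sum(i * i for i in range(10_000))

repeats, number = 7, 10
# timeit.repeat returns one total per repetition, so we can report
# mean and standard deviation instead of a single average
totals = timeit.repeat(work, repeat=repeats, number=number)
per_call = [t / number for t in totals]
print(f"{statistics.mean(per_call):.6f} s ± {statistics.stdev(per_call):.6f} s per call")
```

Reporting the minimum of the repetitions is another common choice, since it is the sample least contaminated by background noise.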
For future reference, let's make it clear that this profiling pattern from the TensorFlow docs is inaccurate:
# average runtime on 10 repetitions without variance is inaccurate
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
The following corrected and extended snippet, executed on Kaggle notebooks with GPU, demonstrates that improvements come mostly from the graph mode and that XLA compilation gives some further speedup.
import timeit
import numpy as np
import tensorflow as tf
model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_tffunc = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
x = np.random.random(size=(1, 384, 384, 3))
def run(model):
    model(x)
# warmup
run(model_plain)
run(model_tffunc)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)
# benchmarking
duration_plain = %timeit -o run(model_plain)
duration_tffunc = %timeit -o run(model_tffunc)
duration_jit_compile_true = %timeit -o run(model_jit_compile_true)
duration_jit_compile_false = %timeit -o run(model_jit_compile_false)
duration_jit_compile_none = %timeit -o run(model_jit_compile_none)
print(f"{str(duration_plain)=}")
print(f"{str(duration_tffunc)=}")
print(f"{str(duration_jit_compile_true)=}")
print(f"{str(duration_jit_compile_false)=}")
print(f"{str(duration_jit_compile_none)=}")
Statistically, we have duration_plain > duration_jit_compile_false ≈ duration_jit_compile_none ≈ duration_tffunc > duration_jit_compile_true (the middle three are indistinguishable within the measured variation), as seen from the output:
369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
str(duration_plain)='369 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)'
str(duration_tffunc)='16.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)'
str(duration_jit_compile_true)='11.6 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'
str(duration_jit_compile_false)='15.9 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)'
str(duration_jit_compile_none)='15.5 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)'
For a complete example, see this public notebook.
NOTE: this way of measuring variation is useful but not fully accurate, since %timeit reports the standard deviation across the means of the 7 runs rather than across individual calls.