python, tensorflow, tensorflow2.0, tensorflow-xla

XLA can't deduce compile time constant output shape for strided slice when using ragged tensor and while loop


Is it possible to get the following minimal example working with experimental_compile=True? I've seen some big speedups with this argument, so I am keen to figure out how to get it working. Thanks!

import tensorflow as tf

print(tf.__version__)
# ===> 2.2.0-dev20200409

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]


@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
#    [[{{node while/RaggedGetItem/strided_slice_4}}]]
#    [[while]]
#   This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
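The slice in the error comes from ragged_tensor[i]: the rows hold different numbers of inner rows (2, 1, 2), so the first dimension of the slice depends on the runtime value of i. A small diagnostic sketch, using a plain tf.function without XLA (the helper name row_at is hypothetical), that records the static shape TensorFlow infers at trace time:

```python
import tensorflow as tf

# Same data as the question.
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, tf.constant([2, 1, 2]))

traced_shapes = []


@tf.function  # plain graph mode, no XLA: we only inspect trace-time shapes
def row_at(i):
    row = ragged_tensor[i]
    # This Python line runs only while tracing, so it records the static
    # shape TensorFlow could infer for the slice, not a runtime value.
    traced_shapes.append(row.shape.as_list())
    return tf.reduce_sum(row)


row_at(tf.constant(0))
print(traced_shapes)
```

A None leading dimension corresponds to the "?" in [?,5] from the error message: the slice's row count is not known until i is, which is exactly what the error says XLA could not treat as a compile-time constant.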

Solution

  • For anyone hitting this issue: I just noticed that this works on TensorFlow 2.5, once experimental_compile is replaced with jit_compile (its newer name):

    import tensorflow as tf
    
    print(tf.__version__)
    # 2.5.0
    
    x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
    row_lengths = tf.constant([2, 1, 2])
    ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
    
    for i, tensor in enumerate(ragged_tensor):
        print(f"i: {i}\ntensor:\n{tensor}\n")
    # ==>
    # i: 0
    # tensor:
    # [[0. 1. 2. 3. 4.]
    #  [5. 6. 7. 8. 9.]]
    
    # i: 1
    # tensor:
    # [[10. 11. 12. 13. 14.]]
    
    # i: 2
    # tensor:
    # [[15. 16. 17. 18. 19.]
    #  [20. 21. 22. 23. 24.]]
    
    
    @tf.function(autograph=False, jit_compile=True)
    def while_loop_works():
    
        num_rows = ragged_tensor.nrows()
    
        def cond(i, _):
            return i < num_rows
    
        def body(i, running_total):
            return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])
    
        _, total = tf.while_loop(cond, body, [0, 0.0])
    
        return total
    
    
    while_loop_works()
    # 2021-06-28 13:18:19.253261: I tensorflow/compiler/jit/xla_compilation_cache.cc:337] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
    # <tf.Tensor: shape=(), dtype=float32, numpy=300.0>
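If upgrading is not an option, or indexing inside the loop still fails, one workaround to try is keeping every shape static by masking instead of slicing. This is a minimal sketch under assumptions: it uses the TF 2.5+ keyword jit_compile (older releases would need experimental_compile), masked_while_loop is a hypothetical name, and I have not checked it against 2.2.0-dev. Each iteration builds a boolean mask over the flat values from value_rowids() and zeroes out everything outside row i, so the loop body never produces a [?,5] slice:

```python
import tensorflow as tf

# Same data as the question: rows of lengths [2, 1, 2] over a 5x5 matrix.
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, tf.constant([2, 1, 2]))


@tf.function(autograph=False, jit_compile=True)  # jit_compile needs TF >= 2.5
def masked_while_loop(values, value_rowids, num_rows):
    """Hypothetical helper: sums every ragged row with a while loop,
    masking the flat values instead of slicing ragged_tensor[i]."""

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # [num_values, 1] boolean mask selecting the flat rows of ragged row i.
        in_row_i = tf.equal(value_rowids, i)[:, tf.newaxis]
        # tf.where broadcasts the mask and keeps the static [5, 5] shape;
        # values outside row i are replaced by 0.
        row_values = tf.where(in_row_i, values, tf.zeros_like(values))
        return i + 1, running_total + tf.reduce_sum(row_values)

    _, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0.0)])
    return total


total = masked_while_loop(
    ragged_tensor.values,
    # Cast both to int32 so they match the int32 loop counter.
    tf.cast(ragged_tensor.value_rowids(), tf.int32),
    tf.cast(ragged_tensor.nrows(), tf.int32),
)
print(total.numpy())
```

The trade-off is that every iteration reads all values, so the work grows with nrows × nvalues instead of the size of row i. For a plain total like this one, tf.reduce_sum(ragged_tensor.values) gives the same sum without any loop.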