XLA can't deduce compile time constant output shape for strided slice when using ragged tensor and while loop

Is it possible to get the following minimal example working with experimental_compile=True? I've seen some big speedups with this argument, so I am keen to figure out how to get it working. Thanks!

import tensorflow as tf

print(tf.__version__)  # ===> 2.2.0-dev20200409

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]

@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total

while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
#    [[{{node while/RaggedGetItem/strided_slice_4}}]]
#    [[while]]
#   This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
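To see why XLA rejects ragged_tensor[i], it helps to look at how a ragged tensor built with from_row_lengths maps onto its flat values. The plain-Python sketch below (no TensorFlow; the helpers row_splits_from_lengths and row_shape are hypothetical, for illustration only) reproduces the partitioning from the example. It shows that the shape of each row depends on i, which matches the [?,5] shape in the error message.

```python
def row_splits_from_lengths(row_lengths):
    """Hypothetical helper: turn row lengths [2, 1, 2] into row splits [0, 2, 3, 5]."""
    splits = [0]
    for length in row_lengths:
        splits.append(splits[-1] + length)
    return splits


def row_shape(row_splits, i, inner_dim):
    """Hypothetical helper: shape of ragged row i, whose leading size depends on i."""
    return (row_splits[i + 1] - row_splits[i], inner_dim)


row_lengths = [2, 1, 2]
splits = row_splits_from_lengths(row_lengths)
shapes = [row_shape(splits, i, 5) for i in range(len(row_lengths))]
print(splits)  # [0, 2, 3, 5]
print(shapes)  # [(2, 5), (1, 5), (2, 5)]
```

Inside tf.while_loop, i is a tensor rather than a Python int. XLA would therefore have to compile a single loop body whose slice has a leading dimension that changes from one iteration to the next, and the error says it requires that shape to be a compile-time constant.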
Answer

There seem to be a lot of limitations on what XLA can do with ragged tensors. The underlying problem is that ragged_tensor[i] has a shape ([row_length, 5]) that depends on the runtime value of i, while XLA needs the output shape of the slice to be known at compile time. I can think of a couple of alternatives that make your example work, but I don't know whether they will be applicable to your real use case. First, you could sum over the ragged dimension(s) in advance, or, in your case, even over all dimensions except the first one. However, this would need to be done outside of XLA, as it does not seem to be able to compile the ragged reduction:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Index into the precomputed per-row sums; ragged_sum[i] is a scalar with a static shape
        return i + 1, running_total + ragged_sum[i]

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total

result = while_loop_works()
# 300.0
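To make the intermediate values concrete, here is a plain-Python sketch of what ragged_sum holds and how the loop accumulates it. It does not use TensorFlow, so it only mirrors the arithmetic, not XLA compilation. Representing x as nested lists is an assumption made for illustration.

```python
# Same values as tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
x = [[float(5 * r + c) for c in range(5)] for r in range(5)]
row_lengths = [2, 1, 2]

# Precompute one scalar per ragged row, as reduce_sum(axis=[1, 2]) does
ragged_sum = []
start = 0
for length in row_lengths:
    rows = x[start:start + length]
    ragged_sum.append(sum(sum(row) for row in rows))
    start += length

# The loop now only reads scalars, whose shape never changes between iterations
total = 0.0
for i in range(len(ragged_sum)):
    total += ragged_sum[i]

print(ragged_sum)  # [45.0, 60.0, 195.0]
print(total)       # 300.0
```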

Alternatively, you can convert the ragged tensor into a regular (dense) tensor with to_tensor(). This pads the shorter rows with zeros, which do not affect your sum. Again, the conversion currently needs to be done outside of XLA:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Convert into a regular, zero-padded tensor of shape [3, 2, 5]
unragged_tensor = ragged_tensor.to_tensor()

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # unragged_tensor[i] has the static shape [2, 5]; the zero padding does not change the sum
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total

result = while_loop_works()
# 300.0
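The following plain-Python sketch (no TensorFlow; nested lists stand in for tensors as an assumption) shows why the zero padding is harmless. Every padded row gets the same shape, which is what XLA needs, while the per-row sums stay the same. It mirrors the default fill value of zero that to_tensor() uses when no default_value is given.

```python
x = [[float(5 * r + c) for c in range(5)] for r in range(5)]
row_lengths = [2, 1, 2]

# Split the dense rows into ragged rows according to row_lengths
ragged = []
start = 0
for length in row_lengths:
    ragged.append(x[start:start + length])
    start += length

# Pad every ragged row up to the longest one with all-zero inner rows
max_len = max(row_lengths)
zero_row = [0.0] * 5
padded = [rows + [zero_row] * (max_len - len(rows)) for rows in ragged]

# Every padded row now has the same shape, and the sums are unchanged
shapes = [(len(rows), len(rows[0])) for rows in padded]
row_sums = [sum(sum(r) for r in rows) for rows in padded]
print(shapes)         # [(2, 5), (2, 5), (2, 5)]
print(row_sums)       # [45.0, 60.0, 195.0]
print(sum(row_sums))  # 300.0
```

The trade-off is memory: padding to the longest row can waste space when row lengths vary a lot, whereas the precomputed-sum approach keeps only one scalar per row.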
