XLA can't deduce compile time constant output shape for strided slice when using ragged tensor and while loop

Is it possible to get the following minimal example working with experimental_compile=True? I've seen some big speedups with this argument hence I am keen to figure out how to get it working. Thanks!

import tensorflow as tf

print(tf.__version__)
# ===> 2.2.0-dev20200409

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]


@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
#    [[{{node while/RaggedGetItem/strided_slice_4}}]]
#    [[while]]
#   This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
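
For reference, the error message itself points at auto-clustering as a fallback when full compilation fails. A minimal sketch of that (assuming the flag spelling from the XLA docs, "--tf_xla_auto_jit=2", rather than the spelling printed in the error) would be:

import os

# Alternative suggested by the error message: let XLA auto-cluster and
# compile whatever subgraphs it can, instead of requiring the whole
# function to compile. Must be set before TensorFlow initializes.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import tensorflow as tf

That said, I'd still like to get experimental_compile=True itself working.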
Asked Apr 13 '20 by Jeff


1 Answer

There seem to be a lot of limitations on what XLA can do with ragged tensors. There are a couple of alternatives I can think of that could make your example work, although I don't know whether they will be applicable to your real use case. One option is to sum over the ragged dimension(s) in advance, or even over all dimensions except the first one in your case. However, this would need to be done outside of XLA, since it does not seem to be able to compile it:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])
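
# Note: ragged_sum is a dense 1-D tensor with one sum per ragged row
# (shape [3] here), so indexing it inside the compiled loop is
# shape-static.
print(ragged_sum.numpy())
# ==> [ 45.  60. 195.]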

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Use the sums computed before
        return i + 1, running_total + ragged_sum[i]

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0

You can also just convert the ragged tensor into a regular tensor, which will pad it with zeros that won't affect your sum. Again, this currently needs to be done outside of XLA:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()
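
# Note: the padded tensor has static shape [3, 2, 5]; the padding rows
# are zeros, so they do not change the sums.
print(unragged_tensor.shape)
# ==> (3, 2, 5)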

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Reduce padded tensor
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0
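
In both cases the tensor indexed inside the loop has a fixed shape, which is what the "output shape must be a compile-time constant" error is asking for.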
Answered Oct 19 '22 by jdehesa