Is it possible to get the following minimal example working with experimental_compile=True? I've seen some big speedups with this argument, hence I am keen to figure out how to get it working. Thanks!
import tensorflow as tf
print(tf.__version__)
# ===> 2.2.0-dev20200409
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]
#
# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]
#
# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]
@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total
while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
# [[{{node while/RaggedGetItem/strided_slice_4}}]]
# [[while]]
# This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
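As the error text itself suggests, one fallback worth noting (a sketch based only on the flag named in that message, not something benchmarked here) is auto-clustering: instead of requiring the whole function to compile, XLA compiles the subgraphs it supports and leaves the rest, such as the ragged slicing, to the regular runtime. The environment variable has to be set before TensorFlow initializes:

import os

# Enable XLA auto-clustering before TensorFlow starts. This is the
# documented form of the flag mentioned in the error message above;
# it may recover part of the speedup without full compilation.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import tensorflow as tf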
There seem to be a lot of limitations on what XLA can do with ragged tensors. There are a couple of alternatives I can think of that could make your example work, but I don't know if they will be applicable to your real use case. On the one hand, you could sum over the ragged dimension(s) in advance, or even over all dimensions except the first one in your case. This, however, would need to be done outside of XLA, as it does not seem to be able to compile it:
import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Sum in advance over the ragged dimensions
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Use the sums computed before
        return i + 1, running_total + ragged_sum[i]

    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total

result = while_loop_works()
print(result.numpy())
# 300.0
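As a side observation (not essential to the workaround, just a simplification of this toy example): once the per-row sums exist as a plain dense tensor, the explicit while loop isn't needed at all, and a single reduction compiles without trouble:

# ragged_sum is a regular dense tensor, so a single reduction over it
# gives the same grand total under XLA compilation.
@tf.function(autograph=False, experimental_compile=True)
def sum_direct():
    return tf.reduce_sum(ragged_sum)

print(sum_direct().numpy())
# 300.0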
You can also just convert the ragged tensor into a regular tensor, which will pad it with zeros that won't affect your sum. Again, this would currently need to be done outside of XLA:
import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Reduce padded tensor
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total

result = while_loop_works()
print(result.numpy())
# 300.0
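For what it's worth, a quick inspection of the tensor built above shows why the padding is harmless: rows are padded out to the longest row length with zeros, so the per-row sums are unchanged:

# The middle ragged row (length 1) is padded to length 2 with zeros.
print(unragged_tensor.shape)   # (3, 2, 5)
print(unragged_tensor[1].numpy())
# [[10. 11. 12. 13. 14.]
#  [ 0.  0.  0.  0.  0.]]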