I am using TensorFlow 2.0 and facing the following situation:
@tf.function
def my_fn(items):
    ...  # do stuff
    return
If items is a dict of Tensors, for example:
item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}
is there a way of using the input_signature argument of tf.function so I can force TF2 to avoid creating multiple graphs when item1 is, for example, tf.zeros([2, 1])?
The input signature has to be a list, but the elements of the list can be dictionaries or lists of TensorSpecs. In your case I would try the following (the name attributes are optional):
signature_dict = {
    # None as the first dimension lets item1's leading size vary
    # without retracing; dtypes match the tf.zeros tensors above
    "item1": tf.TensorSpec(shape=[None, 1], dtype=tf.float32, name="item1"),
    "item2": tf.TensorSpec(shape=[1], dtype=tf.float32, name="item2"),
}
# don't forget the brackets around 'signature_dict'
@tf.function(input_signature=[signature_dict])
def my_fn(items):
    ...  # do stuff
    return
# calling the TensorFlow function
my_fn(items)
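To make the idea concrete, here is a minimal, self-contained sketch (add_items and its body are illustrative stand-ins for your my_fn, since the question elides what "do stuff" is). The None dimension in the spec is what lets item1 change its leading size without a retrace:

```python
import tensorflow as tf

# None in the first dimension allows any leading size for item1,
# so calls with different shapes reuse the same traced graph.
signature_dict = {
    "item1": tf.TensorSpec(shape=[None, 1], dtype=tf.float32, name="item1"),
    "item2": tf.TensorSpec(shape=[1], dtype=tf.float32, name="item2"),
}

@tf.function(input_signature=[signature_dict])
def add_items(items):
    # placeholder body; the original function's contents are not shown
    return items["item1"] + items["item2"]

# Both calls match the single signature, so only one graph is traced.
a = add_items({"item1": tf.zeros([1, 1]), "item2": tf.zeros([1])})
b = add_items({"item1": tf.zeros([2, 1]), "item2": tf.zeros([1])})
```

With a fixed shape like [2, 1] in the spec instead, the first call here would fail its shape check rather than retrace.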
However, if you want to call a particular concrete function created by my_fn, you have to unpack the dictionary. You also have to provide the name attribute in each tf.TensorSpec.
# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)
This is annoying but should be resolved in TensorFlow 2.3 (see the end of the TF guide on concrete functions).