Using a dictionary in tf.function input_signature in TensorFlow 2.0

I am using TensorFlow 2.0 and have the following situation:

import tensorflow as tf

@tf.function
def my_fn(items):
    # ... do stuff with items["item1"] and items["item2"]
    return

Suppose items is a dict of Tensors, for example:

item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}

Is there a way to use the input_signature argument of tf.function so that TF2 does not create multiple graphs when, for example, item1 becomes tf.zeros([2, 1])?
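To make the problem concrete, here is a minimal sketch of the retracing the question describes. The body of my_fn and the `traces` list are hypothetical stand-ins for "do stuff"; `traces` relies on the fact that a Python side effect inside a tf.function runs once per trace, not once per call, so its length counts the graphs that were built (with default tf.function settings).

```python
import tensorflow as tf

traces = []  # hypothetical: filled only while tf.function traces a new graph

@tf.function
def my_fn(items):
    # Python side effect: runs once per trace, not once per call
    traces.append(items["item1"].shape.as_list())
    return items["item1"] * 2.0, items["item2"] + 1.0

my_fn({"item1": tf.zeros([1, 1]), "item2": tf.zeros(1)})  # first call: traces
my_fn({"item1": tf.zeros([1, 1]), "item2": tf.zeros(1)})  # same shapes: graph reused
my_fn({"item1": tf.zeros([2, 1]), "item2": tf.zeros(1)})  # new shape: traces again
print(traces)  # [[1, 1], [2, 1]]
```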

asked Mar 24 '20 at 09:03 by m33n


1 Answer

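The input signature has to be a list (one entry per positional argument of the function), but the elements of that list can themselves be dictionaries or lists of TensorSpecs. A None entry in a TensorSpec shape matches any size along that dimension, which is what lets item1 change from [1, 1] to [2, 1] without a new graph being traced. In your case I would try the following (the name attributes are optional here):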

signature_dict = { "item1": tf.TensorSpec(shape=[None, 1], dtype=tf.float32, name="item1"),
                   "item2": tf.TensorSpec(shape=[1], dtype=tf.float32, name="item2") }

# don't forget the brackets around 'signature_dict'
@tf.function(input_signature=[signature_dict])
def my_fn(items):
    # ... do stuff
    return

# calling the TensorFlow function
my_fn(items)
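A self-contained sketch of this approach, assuming float32 tensors shaped like the ones in the question. The function body and the `traces` list are hypothetical; `traces` records the shape seen at trace time, so a single entry containing None shows that one graph served both calls.

```python
import tensorflow as tf

# None marks the dimension of item1 that may vary between calls.
signature_dict = {
    "item1": tf.TensorSpec(shape=[None, 1], dtype=tf.float32, name="item1"),
    "item2": tf.TensorSpec(shape=[1], dtype=tf.float32, name="item2"),
}

traces = []  # hypothetical: appended to only while tracing

@tf.function(input_signature=[signature_dict])
def my_fn(items):
    traces.append(items["item1"].shape.as_list())  # runs once per trace
    return tf.reduce_sum(items["item1"]) + tf.reduce_sum(items["item2"])

my_fn({"item1": tf.zeros([1, 1]), "item2": tf.zeros(1)})
my_fn({"item1": tf.zeros([2, 1]), "item2": tf.zeros(1)})
print(traces)  # [[None, 1]]
```

Passing a tensor whose shape or dtype is incompatible with the signature (for example an int32 item1) raises an error instead of silently tracing a new graph.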

However, if you want to call a particular concrete function obtained from my_fn, you have to unpack the dictionary into keyword arguments. In that case the name attributes in the tf.TensorSpecs are mandatory, because they become the keyword names.

# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs 
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
                                             
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)

This is inconvenient, but it should be resolved in TensorFlow 2.3 (see the end of the TF guide on 'Concrete functions').
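A minimal sketch of the newer behavior, assuming TensorFlow 2.3 or later, where a concrete function should accept the same nested structure as the original Python function, so the dictionary can be passed as-is. The function body (a broadcasting addition) is hypothetical.

```python
import tensorflow as tf

signature_dict = {
    "item1": tf.TensorSpec(shape=[None, 1], dtype=tf.float32, name="item1"),
    "item2": tf.TensorSpec(shape=[1], dtype=tf.float32, name="item2"),
}

@tf.function(input_signature=[signature_dict])
def my_fn(items):
    # item2 (shape [1]) broadcasts against item1 (shape [None, 1])
    return items["item1"] + items["item2"]

# With an input_signature, no arguments are needed to get the concrete function.
my_concrete_fn = my_fn.get_concrete_function()

items = {"item1": tf.zeros([2, 1]), "item2": tf.ones(1)}
result = my_concrete_fn(items)  # structured call: no ** unpacking
print(result.shape)  # (2, 1)
```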

answered Sep 18 '22 at 14:09 by Sven Meinhardt