How to return a dictionary of tensors from tf.py_function?

A Hugging Face transformers tokenizer usually encodes an input as a dictionary:

{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}

To achieve better performance when handling a large dataset, it is good practice to build an input pipeline that uses Dataset.map to apply a tokenizer function to each element of the input dataset, exactly as done in the TensorFlow tutorial "Load text".

However, tf.py_function (used to wrap the Python function passed to map) does not support returning a dictionary of tensors like the one shown above.
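For contrast, Dataset.map itself has no trouble with dictionaries: when the mapped function is built only from TensorFlow ops (so it can be traced into a graph), it may return a dict of tensors directly. The restriction comes from tf.py_function, whose Tout lists one dtype per returned tensor. A minimal sketch, assuming the text has already been turned into toy token IDs and using a hypothetical to_features function in place of a real tokenizer:

```python
import tensorflow as tf

# Toy token IDs standing in for already-tokenized text (assumed values).
ids = tf.constant([[101, 13366, 2131, 102],
                   [101, 2094, 1035, 102]], dtype=tf.int32)

def to_features(input_ids):
  # Hypothetical map function made only of TF ops, so Dataset.map can trace it;
  # a traced map function is allowed to return a dict of tensors.
  return {"input_ids": input_ids,
          "attention_mask": tf.ones_like(input_ids)}

ds = tf.data.Dataset.from_tensor_slices(ids).map(to_features)

# Each element is a dict keyed by feature name.
print(ds.element_spec)
for features in ds.take(1):
  print(features["input_ids"].numpy(), features["attention_mask"].numpy())
```

This only works when the tokenization can be expressed in TensorFlow ops; a pure-Python tokenizer still needs tf.py_function, which is where the dict limitation appears.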

For instance, if the tokenizer (encoder) in the Load text tutorial returns the following dictionary:

{
    "input_ids": [ 101, 13366,  2131,  1035,  6819,  2094,  1035,  102 ],
    "attention_mask": [ 1, 1, 1, 1, 1, 1, 1, 1 ]
}

how can one set the Tout parameter of tf.py_function to get the desired dictionary of tensors below?

{
    'input_ids': <tf.Tensor: shape=(8,), dtype=int32, numpy=array(
    [ 101, 13366,  2131,  1035,  6819,  2094,  1035,  102 ], dtype=int32)>,

    'attention_mask': <tf.Tensor: shape=(8,), dtype=int32, numpy=array(
     [ 1, 1, 1, 1, 1, 1, 1, 1 ], dtype=int32)>
}

Celso França asked Oct 24 '25 05:10


1 Answer

tf.py_function does not allow a Python dict as its return type (see https://github.com/tensorflow/tensorflow/issues/36276).

As a workaround in your case, you can do the data transformation inside your py_function and return a tuple of tensors, then call another Dataset.map, this time without py_function, that packs the tuple into a dictionary.

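import tensorflow as tf

def gen():
  # Dummy generator standing in for the raw input.
  yield 1

def process_data(x):
  # Runs eagerly inside tf.py_function; stands in for the tokenizer call.
  # Returns a tuple, because tf.py_function cannot return a dict.
  return ([ 101, 13366,  2131,  1035,  6819,  2094,  1035,  102 ],
          [ 1, 1, 1, 1, 1, 1, 1, 1 ])

def create_dict(input_ids, attention_mask):
  # Ordinary (graph-mode) map function: packs the tuple into a dict.
  return {"input_ids": tf.convert_to_tensor(input_ids),
          "attention_mask": tf.convert_to_tensor(attention_mask)}

ds = (tf.data.Dataset
      .from_generator(gen, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
      .map(lambda x: tf.py_function(process_data, inp=[x],
                                    Tout=(tf.int32, tf.int32)))
      .map(create_dict)
      .repeat())

for x in ds:
  print(x)
  break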

Output:

{'input_ids': <tf.Tensor: shape=(8,), dtype=int32, numpy=
array([  101, 13366,  2131,  1035,  6819,  2094,  1035,   102],
      dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(8,), dtype=int32, numpy=array([1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)>}
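When the tokenizer returns more keys (for example token_type_ids), spelling them out in two separate functions gets repetitive. The same tuple-then-dict workaround can be packaged into one helper that fixes the key order in a single place. A minimal sketch, assuming 1-D integer outputs; py_function_to_dict and encode are hypothetical names, and encode is a toy whitespace "tokenizer", not a real transformers tokenizer:

```python
import tensorflow as tf

# Assumed fixed key order: values are flattened and rebuilt in this order.
KEYS = ("input_ids", "attention_mask")

def encode(text):
  # Hypothetical stand-in for a real tokenizer; returns a dict of lists.
  # Runs eagerly inside tf.py_function, so `text` is an EagerTensor here.
  n = len(text.numpy().split())
  return {"input_ids": list(range(101, 101 + n)),
          "attention_mask": [1] * n}

def py_function_to_dict(func, keys, dtypes):
  """Hypothetical helper: wrap a dict-returning Python function for Dataset.map."""
  def flat_func(x):
    # tf.py_function cannot return a dict, so return the values in `keys` order.
    out = func(x)
    return [out[k] for k in keys]

  def mapper(x):
    tensors = tf.py_function(flat_func, inp=[x], Tout=list(dtypes))
    result = {}
    for k, t in zip(keys, tensors):
      # Outputs of tf.py_function have unknown static shape; declare them 1-D.
      t.set_shape([None])
      result[k] = t
    return result

  return mapper

ds = (tf.data.Dataset.from_tensor_slices(["def get vid", "hello world"])
      .map(py_function_to_dict(encode, KEYS, (tf.int32, tf.int32))))

for features in ds:
  print({k: v.numpy() for k, v in features.items()})
```

Keeping KEYS in one place keeps the flattening inside the py_function and the rebuilding outside it in sync, and set_shape restores enough static shape information for later steps such as padded batching.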
Mahendra Singh Meena answered Oct 26 '25 20:10