
How to explain the result of tf.map_fn?

Tags:

tensorflow

Look at the code:

import tensorflow as tf
import numpy as np

elems = tf.ones([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

The output is:

(array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64))

I can't understand this output. Can someone explain it?

Update

elems is a tensor, so it should be unpacked along axis 0, and we get [[1,1,1],[1,1,1]]. Then map_fn passes [[1,1,1],[1,1,1]] into lambda x: (x, x, x), which means x = [[1,1,1],[1,1,1]], so I think the output of map_fn should be

[[[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]]]

That is, the output would have shape [3,2,3], or be a list of three tensors of shape (2,3).

But in fact, the output is a tuple of tensors, each of shape [1,2,3].

Or in other words:

import tensorflow as tf
import numpy as np

elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

Why is the output

(array([1, 2, 3], dtype=int64), 
 array([2, 4, 6], dtype=int64), 
 array([-1, -2, -3], dtype=int64))

rather than

(array([1, 2, -1], dtype=int64), 
 array([2, 4, -2], dtype=int64), 
 array([3, 6, -3], dtype=int64))

The two questions are the same.

Update 2

import tensorflow as tf
import numpy as np

elems = [tf.constant([1,2,3],dtype=tf.int64)]
alternates = tf.map_fn(lambda x: x, elems, dtype=tf.int64)
with tf.Session() as sess:
    print(sess.run(alternates))

elems is a list of tensors, so according to the API, tf.constant([1,2,3],dtype=tf.int64) should be unpacked along axis 0, and map_fn should work like [x for x in [1,2,3]]. But in fact it raises an error:

ValueError: The two structures don't have the same nested structure. First struc
ture: <dtype: 'int64'>, second structure: [<tf.Tensor 'map/while/TensorArrayRead
V3:0' shape=() dtype=int64>].

What's wrong?

Update 3

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

The output is

(array([1, 2, 3], dtype=int64), array([1, 2, 3], dtype=int64))

It seems that elems isn't unpacked. Why?

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: [x], elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

It raises an error:

TypeError: The two structures don't have the same sequence type. First structure
 has type <class 'tuple'>, while second structure has type <class 'list'>.
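Both error messages come from the same kind of nested-structure comparison: the declared dtype structure must match the structure that fn actually returns, down to the list-versus-tuple distinction. As a rough, hypothetical plain-Python sketch of such a check (illustrative names only; the real checking is done by TensorFlow's nest utilities):

```python
def same_structure(a, b):
    # Illustrative only: the real check lives in TensorFlow's nest utils.
    a_seq = isinstance(a, (list, tuple))
    b_seq = isinstance(b, (list, tuple))
    if a_seq != b_seq:
        return False          # a leaf vs. a nested sequence -> the ValueError case
    if not a_seq:
        return True           # two leaves always match
    if type(a) is not type(b):
        return False          # tuple vs. list -> the TypeError case
    return len(a) == len(b) and all(same_structure(x, y) for x, y in zip(a, b))

print(same_structure("int64", ["int64"]))             # → False (leaf vs. list)
print(same_structure(("int64", "int64"), ["int64"]))  # → False (tuple vs. list)
print(same_structure(("int64",), ("int64",)))         # → True
```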

Who can tell me how tf.map_fn works?

asked Sep 07 '17 by gaussclb



2 Answers

First,

elems = tf.ones([1,2,3],dtype=tf.int64)

elems is a 3-dimensional tensor with shape 1x2x3 full of ones, that is:

[[[1, 1, 1],
  [1, 1, 1]]]

Then,

alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))

alternates is a tuple of three tensors with the same shape as elems, each of which is built according to the given function. Since the function simply returns a tuple repeating its input three times, that means that the three tensors will be the same as elems. If the function were lambda x: (x, 2 * x, -x) then the first output tensor would be the same as elems, the second would be the double of elems and the third one the opposite.

In all these cases it is preferable to use regular operations instead of tf.map_fn; however, there may be cases where you have a function that accepts tensors with N dimensions and a tensor with N + 1 dimensions that you want it applied to.

UPDATE:

I think you are thinking of tf.map_fn "the other way around", so to speak. There is not a one-to-one correspondence between the number of elements or rows in the tensor and the number of outputs of the function; in fact, you could pass a function returning a tuple with as many elements as you want.

Taking your last example:

elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))

tf.map_fn first splits elems along the first axis, that is, into 1, 2 and 3, and applies the function to each of them, getting:

(1, 2, -1)
(2, 4, -2)
(3, 6, -3)

Note that, as I said, each of these tuples could have as many elements as you want. The final output is then produced by stacking the results at each position; so you get:

[1, 2, 3]
[2, 4, 6]
[-1, -2, -3]

Again, if the function produced tuples with more elements you would get more output tensors.
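The two steps just described, map and then regroup by position, can be sketched in plain Python (an illustration of the semantics, not TensorFlow code):

```python
def map_fn_sketch(fn, elems):
    # Step 1: unpack elems along axis 0 and apply fn to each slice.
    results = [fn(x) for x in elems]          # [(1, 2, -1), (2, 4, -2), (3, 6, -3)]
    # Step 2: regroup by position -- all first elements form the first
    # output, all second elements the second output, and so on.
    return tuple(list(col) for col in zip(*results))

out = map_fn_sketch(lambda x: (x, 2 * x, -x), [1, 2, 3])
print(out)  # → ([1, 2, 3], [2, 4, 6], [-1, -2, -3])
```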

UPDATE 2:

About your new example:

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

The documentation says:

This method also allows multi-arity elems and output of fn. If elems is a (possibly nested) list or tuple of tensors, then each of these tensors must have a matching first (unpack) dimension. The signature of fn may match the structure of elems. That is, if elems is (t1, [t2, t3, [t4, t5]]), then an appropriate signature for fn is: fn = lambda (t1, [t2, t3, [t4, t5]]):.

Here elems is a tuple of two tensors with the same size in the first dimension, as required. tf.map_fn takes one element of each input tensor at a time (so a tuple of two scalars) and applies the given function to it, which should return the same structure that you declared in dtype (a tuple of two elements, too); if you don't give a dtype, the expected output structure is the same as the input (again, a tuple of two elements, so in your case dtype is optional). Anyway, it goes like this:

f((1, 1)) -> (1, 1)
f((2, 2)) -> (2, 2)
f((3, 3)) -> (3, 3)

These results are combined, concatenating all the corresponding elements in the structure; in this case, all the numbers in the first position produce the first output and all the numbers in the second positions produce the second output. The result is, finally, the requested structure (the two-element tuple) filled with these concatenations:

([1, 2, 3], [1, 2, 3])
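The multi-arity case follows the same map-then-regroup pattern; here is a plain-Python sketch of it (illustrative only, not the real implementation):

```python
def map_fn_multi_sketch(fn, elems):
    # elems is a tuple of sequences; zip pairs up their slices, so fn
    # receives one tuple per step: (1, 1), (2, 2), (3, 3).
    results = [fn(x) for x in zip(*elems)]
    # Regroup by position to rebuild the requested output structure.
    return tuple(list(col) for col in zip(*results))

# Identity function, as in the example above:
print(map_fn_multi_sketch(lambda x: x, ([1, 2, 3], [1, 2, 3])))
# → ([1, 2, 3], [1, 2, 3])

# A function that actually combines the two inputs per step:
print(map_fn_multi_sketch(lambda x: (x[0] + x[1], x[0] - x[1]),
                          ([1, 2, 3], [4, 5, 6])))
# → ([5, 7, 9], [-3, -3, -3])
```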
answered Sep 28 '22 by jdehesa


Your input elems have shape (1,2,3) and look like this:

[[[1, 1, 1],
 [1, 1, 1]]]

It's not a matrix containing the values 1, 2, 3: tf.ones() makes a tensor filled with ones, with the shape you pass as a parameter.

Replying to the Update:

map_fn is applied to elems itself. According to tf.map_fn's documentation:

elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along their first dimension. The nested sequence of the resulting slices will be applied to fn.

From what I understand there, the function expects a tensor or a nested sequence of tensors, unpacks it along the first dimension, and applies the function to each slice. Since elems has shape (1,2,3), there is exactly one slice of shape (2,3); the lambda receives that slice as x and returns a tuple with three copies of it. map_fn then stacks each position of the results back along axis 0, so each of the three output tensors has shape (1,2,3) again (which is the array(...) in your output).

Restructuring the output line and adding indent to make it more clear, the output looks as follows:

( 
   array( # first copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
   array( # second copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
   array( # third copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
) # end of the tuple
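The shape bookkeeping can be sketched with plain Python lists standing in for the tensors (an illustration of the unpack/stack steps, not TensorFlow code):

```python
# Plain-Python stand-in for the (1, 2, 3) tensor of ones:
elems = [[[1, 1, 1],
          [1, 1, 1]]]

# Unpack along axis 0, apply the lambda to each slice, regroup by position.
slices = [x for x in elems]                # one slice of shape (2, 3)
results = [(x, x, x) for x in slices]      # the lambda returns three copies
outputs = tuple([r[i] for r in results] for i in range(3))

print(outputs[0] == elems)  # → True: each output has shape (1, 2, 3) again
```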

Update 2:

My suspicion is that you ran into a bug. If you define elems as a list, you get the error, but if you define it as a tuple with elems = (tf.constant([1,2,3],dtype=tf.int64)), the code works as expected. The different handling of tuples and lists is very suspicious, which is why I believe it's a bug. As @mrry pointed out, in my example with the tuple I missed a comma (and thus elems was the tensor itself and not a tuple containing the tensor).

answered Sep 28 '22 by GPhilo