
Theano Dimshuffle equivalent in Google's TensorFlow?

I have seen that transpose and reshape together can help, but I don't know how to use them.

E.g. dimshuffle(0, 'x')

What is its equivalent using transpose and reshape? Or is there a better way? Thank you.

Mihir Kavatkar asked Feb 02 '16 21:02


2 Answers

There are three relevant ops for implementing Theano's dimshuffle in TensorFlow:

  • tf.transpose() is used to permute the dimensions of a tensor. If the pattern specified in the arguments to dimshuffle is a permutation of the input tensor's dimensions (i.e. there is no 'x' or missing dimension) you can use tf.transpose() to implement dimshuffle().

  • tf.expand_dims() is used to add a size-1 dimension to a tensor (one per call; chain calls to add several). This handles the case where 'x' is specified as part of the dimshuffle() pattern, but it does not reorder the existing dimensions.

  • tf.squeeze() is used to remove one or more size-1 dimensions from a tensor. This handles the case where a dimension is omitted from a dimshuffle() pattern, but it does not reorder the existing dimensions.
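As a rough illustration of the three cases above, the sketch below inspects a dimshuffle pattern and reports which TensorFlow ops it would call for. The helper name ops_for_pattern is hypothetical (it is not part of Theano or TensorFlow), and it looks at the pattern alone without building any tensors.

```python
def ops_for_pattern(pattern, ndim):
    """Return the TensorFlow ops a dimshuffle pattern needs, in application order.

    Hypothetical helper for illustration only; `pattern` is a dimshuffle
    pattern such as (1, 'x', 0) and `ndim` is the rank of the input tensor.
    """
    kept = [p for p in pattern if p != 'x']  # input axes that survive
    ops = []
    if len(kept) < ndim:
        # Some input axis is omitted; it must have size 1 to be dropped.
        ops.append('tf.squeeze')
    if kept != sorted(kept):
        # The surviving axes appear in a different order.
        ops.append('tf.transpose')
    if 'x' in pattern:
        # Each 'x' asks for a new size-1 axis.
        ops.append('tf.expand_dims')
    return ops


print(ops_for_pattern((0, 'x'), ndim=1))     # ['tf.expand_dims']
print(ops_for_pattern((1, 'x', 0), ndim=2))  # ['tf.transpose', 'tf.expand_dims']
print(ops_for_pattern((0,), ndim=2))         # ['tf.squeeze']
```

The sketch does not validate the pattern (for example, repeated axes), so treat it as a reading aid rather than a checker.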

Assuming that the input is a vector, your example (dimshuffle(0, 'x')) can be expressed using tf.expand_dims() only:

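# TensorFlow 1.x graph-mode API; in 2.x, tf.placeholder lives under tf.compat.v1.
import tensorflow as tf

input = tf.placeholder(tf.float32, [None])  # Defines an arbitrary-sized vector.
result = tf.expand_dims(input, 1)

print(result.get_shape())  # ==> TensorShape([Dimension(None), Dimension(1)])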

Taking a more complicated example, dimshuffle(1, 'x', 0) applied to a matrix would be:

input = tf.placeholder(tf.float32, [128, 32])  # Defines a matrix.
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1)

print(output.get_shape())
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)])
mrry answered Oct 26 '22 07:10


I implemented dimshuffle for TensorFlow in our framework Returnn. The code is:

import tensorflow as tf


def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
  """
  :param tf.Tensor x:
  :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
  :param str name: scope name
  :return: y where we have a new broadcast axis for each axis in axes
  :rtype: tf.Tensor
  """
  with tf.name_scope(name):
    for i in sorted(axes):
      x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
    return x


def dimshuffle(x, axes, name="dimshuffle"):
  """
  Like Theano's dimshuffle.
  Combines tf.transpose, tf.expand_dims and tf.squeeze.

  :param tf.Tensor x:
  :param list[int|str]|tuple[int|str] axes:
  :param str name: scope name
  :rtype: tf.Tensor
  """
  with tf.name_scope(name):
    assert all([i == "x" or isinstance(i, int) for i in axes])
    real_axes = [i for i in axes if isinstance(i, int)]
    bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
    if x.get_shape().ndims is None:
      x_shape = tf.shape(x)
      x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)])  # will have static ndims
    assert x.get_shape().ndims is not None

    # First squeeze missing axes.
    i = 0
    while i < x.get_shape().ndims:
      if i not in real_axes:
        x = tf.squeeze(x, axis=i)
        real_axes = [(j if (j < i) else (j - 1)) for j in real_axes]
      else:
        i += 1

    # Now permute.
    assert list(sorted(real_axes)) == list(range(x.get_shape().ndims))
    if real_axes != list(range(x.get_shape().ndims)):
      x = tf.transpose(x, real_axes)

    # Now add broadcast dimensions.
    if bc_axes:
      x = expand_multiple_dims(x, bc_axes)
    assert len(axes) == x.get_shape().ndims
    return x
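
To sanity-check what shape a given pattern should produce, the three steps in the function above (squeeze, permute, expand) can be mirrored on a plain shape tuple. This is a minimal sketch that assumes the input shape is fully known; dimshuffle_shape is a hypothetical name, not part of Returnn, Theano, or TensorFlow.

```python
def dimshuffle_shape(shape, pattern):
    """Predict the static shape produced by dimshuffle(pattern).

    Hypothetical helper: works on a tuple of ints, not on tensors.
    """
    kept = [p for p in pattern if p != 'x']
    if len(set(kept)) != len(kept):
        raise ValueError("an input axis appears twice in %r" % (pattern,))
    for axis, size in enumerate(shape):
        # Only size-1 axes may be dropped (this is the tf.squeeze step).
        if axis not in kept and size != 1:
            raise ValueError("cannot drop axis %d of size %d" % (axis, size))
    # Permutation and broadcast axes in one pass: 'x' becomes 1,
    # an integer picks the size of that input axis.
    return tuple(1 if p == 'x' else shape[p] for p in pattern)


print(dimshuffle_shape((128, 32), (1, 'x', 0)))  # (32, 1, 128)
print(dimshuffle_shape((5, 1), (0,)))            # (5,)
print(dimshuffle_shape((7,), (0, 'x')))          # (7, 1)
```

The first call matches the dimshuffle(1, 'x', 0) example from the other answer, which gives a quick cross-check between the two approaches.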
Albert answered Oct 26 '22 05:10