I have seen that transpose and reshape together can help, but I don't know how to use them.
E.g. dimshuffle(0, 'x')
What is its equivalent using transpose and reshape? Or is there a better way? Thank you.
There are three relevant ops for implementing Theano's dimshuffle in TensorFlow:
tf.transpose() is used to permute the dimensions of a tensor. If the pattern specified in the arguments to dimshuffle is a permutation of the input tensor's dimensions (i.e. there is no 'x' or missing dimension), you can use tf.transpose() to implement dimshuffle().
tf.expand_dims() is used to add a size-1 dimension to a tensor (one dimension per call, so call it once for each 'x'). This handles the case where 'x' is specified as part of the dimshuffle() pattern, but it does not reorder the existing dimensions.
tf.squeeze() is used to remove one or more size-1 dimensions from a tensor. This handles the case where a dimension is omitted from a dimshuffle() pattern (which Theano only allows for broadcastable, i.e. size-1, dimensions), but it does not reorder the existing dimensions.
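As a rough illustration of how the three ops divide the work, the following pure-Python sketch (no TensorFlow needed; the name dimshuffle_plan is hypothetical) splits a dimshuffle pattern into the arguments you would pass to tf.squeeze(), tf.transpose() and tf.expand_dims(), applied in that order. It only looks at the pattern and the input rank, so it cannot check that the dropped dimensions really have size 1.

```python
def dimshuffle_plan(ndim, pattern):
    """Split a Theano-style dimshuffle pattern into three steps.

    Hypothetical helper for illustration only. Returns
    (squeeze_axes, perm, expand_axes), meant to be applied as
    tf.squeeze -> tf.transpose -> tf.expand_dims (one call per axis,
    in ascending order).
    """
    kept = [p for p in pattern if p != "x"]
    if len(set(kept)) != len(kept) or not all(
            isinstance(p, int) and 0 <= p < ndim for p in kept):
        raise ValueError("invalid pattern %r for rank %d" % (pattern, ndim))
    # Input axes that do not appear in the pattern must be squeezed away.
    squeeze_axes = [i for i in range(ndim) if i not in kept]
    # After squeezing, the surviving axes are renumbered 0..len(kept)-1 in
    # their original order, so the permutation must use the new numbers.
    survivors = sorted(kept)
    perm = [survivors.index(p) for p in kept]
    # Each 'x' becomes a new size-1 axis at its position in the output.
    expand_axes = [i for i, p in enumerate(pattern) if p == "x"]
    return squeeze_axes, perm, expand_axes


print(dimshuffle_plan(1, (0, "x")))     # ([], [0], [1])
print(dimshuffle_plan(2, (1, "x", 0)))  # ([], [1, 0], [1])
print(dimshuffle_plan(3, (2, "x", 0)))  # ([1], [1, 0], [1])
```

The first two plans correspond to the vector and matrix examples in this answer; the third one needs all three ops, because input axis 1 is dropped before the remaining two axes are swapped.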
Assuming that the input is a vector, your example (dimshuffle(0, 'x')) can be expressed using tf.expand_dims() only:
input = tf.placeholder(tf.float32, [None])  # Defines an arbitrary-sized vector (TF 1.x graph API).
result = tf.expand_dims(input, 1)
print(result.get_shape())  # ==> TensorShape([Dimension(None), Dimension(1)])
Taking a more complicated example, dimshuffle(1, 'x', 0) applied to a matrix would be:
input = tf.placeholder(tf.float32, [128, 32])  # Defines a matrix.
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1)
print(output.get_shape())
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)])
I implemented dimshuffle for TensorFlow in our framework Returnn. Here is the code:
import tensorflow as tf
def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
    """
    :param tf.Tensor x:
    :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
    :param str name: scope name
    :return: y where we have a new broadcast axis for each axis in axes
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        # Insert in ascending order so each axis index refers to its final position.
        for i in sorted(axes):
            x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
        return x
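The sorted() call matters: each entry of axes is a position in the final result, and inserting in ascending order makes every later insertion land at its intended index. The following pure-Python sketch (the helper expand_shape is hypothetical) mimics that loop on a shape list, using the same range rule as tf.expand_dims for non-negative axes (0 <= axis <= current rank). Like the function above, it assumes the axes are non-negative.

```python
def expand_shape(shape, axes, sort=True):
    """Shape-only mimic of expand_multiple_dims (hypothetical helper).

    Inserts a size-1 dimension for each axis. As with tf.expand_dims and a
    non-negative axis, each insertion position must satisfy 0 <= axis <= rank.
    """
    shape = list(shape)
    for axis in (sorted(axes) if sort else axes):
        if not 0 <= axis <= len(shape):
            raise ValueError("axis %d out of range for rank %d" % (axis, len(shape)))
        shape.insert(axis, 1)
    return shape


print(expand_shape([7, 9], [3, 1]))  # [7, 1, 9, 1]: size 1 at axes 1 and 3
try:
    expand_shape([7, 9], [3, 1], sort=False)
except ValueError as e:
    print(e)  # axis 3 out of range for rank 2
```

Without sorting, the first requested axis (3) does not exist yet on a rank-2 shape, so the insertion fails.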
def dimshuffle(x, axes, name="dimshuffle"):
    """
    Like Theano's dimshuffle.
    Combines tf.transpose, tf.expand_dims and tf.squeeze.
    :param tf.Tensor x:
    :param list[int|str]|tuple[int|str] axes:
    :param str name: scope name
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        assert all([i == "x" or isinstance(i, int) for i in axes])
        real_axes = [i for i in axes if isinstance(i, int)]
        bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
        if x.get_shape().ndims is None:
            x_shape = tf.shape(x)
            # Assumes the input rank is max(real_axes) + 1; the result will have static ndims.
            x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)])
        assert x.get_shape().ndims is not None
        # First squeeze missing axes.
        i = 0
        while i < x.get_shape().ndims:
            if i not in real_axes:
                x = tf.squeeze(x, axis=i)
                real_axes = [(j if (j < i) else (j - 1)) for j in real_axes]
            else:
                i += 1
        # Now permute.
        assert list(sorted(real_axes)) == list(range(x.get_shape().ndims))
        if real_axes != list(range(x.get_shape().ndims)):
            x = tf.transpose(x, real_axes)
        # Now add broadcast dimensions.
        if bc_axes:
            x = expand_multiple_dims(x, bc_axes)
        assert len(axes) == x.get_shape().ndims
        return x
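To trace what dimshuffle() does to shapes without building a graph, here is a pure-Python mirror of its three phases (squeeze, transpose, expand). It is only a sketch for illustration: the name dimshuffle_shape is made up, it handles fully known static shapes only, and the explicit size check stands in for the error tf.squeeze() raises when asked to remove a dimension whose size is not 1.

```python
def dimshuffle_shape(shape, axes):
    """Shape-only mirror of dimshuffle() (hypothetical, for illustration)."""
    shape = list(shape)
    real_axes = [a for a in axes if a != "x"]
    bc_axes = [i for i, a in enumerate(axes) if a == "x"]
    # Phase 1: squeeze input axes missing from the pattern, renumbering the rest.
    i = 0
    while i < len(shape):
        if i not in real_axes:
            if shape[i] != 1:
                raise ValueError("cannot drop axis %d of size %d" % (i, shape[i]))
            del shape[i]
            real_axes = [j if j < i else j - 1 for j in real_axes]
        else:
            i += 1
    if sorted(real_axes) != list(range(len(shape))):
        raise ValueError("pattern %r does not match the input rank" % (axes,))
    # Phase 2: permute (output axis k takes input axis real_axes[k]).
    shape = [shape[j] for j in real_axes]
    # Phase 3: insert the broadcast axes in ascending order.
    for axis in sorted(bc_axes):
        shape.insert(axis, 1)
    return shape


print(dimshuffle_shape([128, 32], (1, "x", 0)))  # [32, 1, 128]
print(dimshuffle_shape([1, 5], ("x", 1)))        # [1, 5]
```

The second call drops the leading size-1 axis and then re-adds a broadcast axis in the same place, which exercises both the squeeze and the expand phase.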