
What is the function in TensorFlow that is equivalent to expand() in PyTorch?

Let's say I have a 2 x 3 matrix and I want to create a 6 x 2 x 3 matrix where each element in the first dimension is the original 2 x 3 matrix.

In PyTorch, I can do this:

import numpy as np
import torch

x = np.array([[1, 2, 3], [4, 5, 6]])
# torch.autograd.Variable is deprecated since PyTorch 0.4; plain tensors work directly
x = torch.from_numpy(x)

# y is the desired result, shape (6, 2, 3)
y = x.unsqueeze(0).expand(6, 2, 3)
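To make the PyTorch behavior concrete, here is a small sketch (assuming a recent PyTorch and NumPy). It shows that expand() returns a view rather than a copy: the new leading dimension gets stride 0, so all six "copies" read the same memory. It also shows that expand() can prepend the leading dimension by itself, which makes unsqueeze(0) optional.

```python
import numpy as np
import torch

x = torch.from_numpy(np.array([[1, 2, 3], [4, 5, 6]]))  # shape (2, 3)

# expand() returns a view: no data is copied, and the new leading
# dimension gets stride 0, so every slice y[i] reads the same memory.
y = x.unsqueeze(0).expand(6, 2, 3)

print(y.shape)                        # torch.Size([6, 2, 3])
print(y.stride())                     # (0, 3, 1)
print(y.data_ptr() == x.data_ptr())   # True

# expand() can also add new leading dimensions itself, so unsqueeze(0) is optional.
z = x.expand(6, 2, 3)
print(torch.equal(y, z))              # True
```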

What is the equivalent way to do this in TensorFlow? I know unsqueeze() is equivalent to tf.expand_dims(), but I don't think TensorFlow has anything equivalent to expand(). I'm thinking of using tf.concat on a list of the 1 x 2 x 3 tensors, but I'm not sure whether this is the best way to do it.
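A minimal sketch of the tf.concat idea from the question, assuming TensorFlow 2.x with eager execution (needed for .numpy()). tf.tile is shown as a one-call alternative; both produce a fully copied (6, 2, 3) tensor rather than a view.

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.array([[1, 2, 3], [4, 5, 6]]))  # shape (2, 3)
x0 = tf.expand_dims(x, 0)                           # shape (1, 2, 3)

# Concatenate six copies of the (1, 2, 3) tensor along axis 0.
y_concat = tf.concat([x0] * 6, axis=0)              # shape (6, 2, 3)

# tf.tile repeats x0 6 times along axis 0 and once along the other axes.
y_tile = tf.tile(x0, [6, 1, 1])                     # shape (6, 2, 3)

print(y_concat.shape, y_tile.shape)
```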

Asked by mauna on Jan 12 '18 at 12:01


1 Answer

The TensorFlow equivalent of PyTorch's expand() is tf.broadcast_to.

Docs: https://www.tensorflow.org/api_docs/python/tf/broadcast_to
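Applied to the question's example, a sketch assuming TensorFlow 2.x with eager execution. tf.broadcast_to follows NumPy broadcasting rules, so it can prepend the leading dimension directly, and going through tf.expand_dims first gives the same result. One difference from expand() is worth noting. According to the TensorFlow documentation, tf.broadcast_to materializes the full broadcast tensor in memory, whereas PyTorch's expand() returns a view. If the result only feeds an elementwise operation, relying on implicit broadcasting in that operation avoids the copy.

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

# Broadcasting prepends the missing leading dimension, like x.expand(6, 2, 3).
y = tf.broadcast_to(x, [6, 2, 3])         # shape (6, 2, 3)

# Equivalent to the unsqueeze-then-expand pattern from the question.
y2 = tf.broadcast_to(tf.expand_dims(x, 0), [6, 2, 3])

print(y.shape)
print(np.array_equal(y.numpy(), y2.numpy()))
```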

Answered by funkyyyyyy on Oct 30 '22 at 16:10