 

How to explicitly broadcast a tensor to match another's shape in tensorflow?

Tags:

tensorflow

I have three tensors, A, B and C, in TensorFlow. A and B are both of shape (m, n, r); C is a binary tensor of shape (m, n, 1).

I want to select elements from either A or B based on the value of C. The obvious tool is tf.select; however, it does not have broadcasting semantics, so I first need to explicitly broadcast C to the same shape as A and B.

Here is my first attempt, but it fails because TensorFlow doesn't like me mixing a tensor (tf.shape(A)[2]) into the shape list:

import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))  # boolean mask of shape (20, 100, 1)

C = tf.tile(C, [1, 1, tf.shape(A)[2]])     # fails here: multiples list mixes Python ints with a tensor
D = tf.select(C, A, B)

What's the correct approach here?

asked Dec 18 '15 by wxs


2 Answers

EDIT: In all versions of TensorFlow since 0.12rc0, the code in the question works directly. TensorFlow will automatically stack tensors and Python numbers into a tensor argument. The solution below using tf.pack() is only needed in versions prior to 0.12rc0. Note that tf.pack() was renamed to tf.stack() in TensorFlow 1.0.
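
For reference, here is a minimal sketch of the question's code written against the post-1.0 names (an assumption on my part: a TensorFlow 1.x install, where tf.pack() became tf.stack() and tf.select() became tf.where()):

import tensorflow as tf

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))         # boolean mask, shape (20, 100, 1)

C = tf.tile(C, tf.stack([1, 1, tf.shape(A)[2]]))  # tile the last axis to match A
D = tf.where(C, A, B)                             # element-wise select

On 0.12rc0 and later the explicit tf.stack() is optional, since the mixed list is converted automatically.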


Your solution is very close to working. You should replace the line:

C = tf.tile(C, [1, 1, tf.shape(A)[2]])

...with the following:

C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))

(The reason for the issue is that TensorFlow won't implicitly convert a mixed list of tensors and Python literals into a tensor. tf.pack() takes a list of tensors, so it converts each element of its input (1, 1, and tf.shape(A)[2]) to a tensor. Since each element is a scalar, the result is a vector.)
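
As a quick illustration (a sketch assuming a pre-1.0 TensorFlow, where tf.pack() and tf.Session() exist), evaluating the packed multiples shows the scalars collected into a single 1-D tensor:

with tf.Session() as sess:
    print(sess.run(tf.pack([1, 1, tf.shape(A)[2]])))   # => [ 1  1 10]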

answered Oct 13 '22 by mrry


Here's a dirty hack:

import tensorflow as tf

def broadcast(tensor, shape):
    # Adding a zero tensor of the target shape triggers implicit broadcasting.
    return tensor + tf.zeros(shape, dtype=tensor.dtype)

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])

C = broadcast(C, A.shape)                  # broadcast while C is still numeric
C = tf.greater_equal(C, tf.zeros_like(C))  # then threshold to the boolean mask tf.select needs
D = tf.select(C, A, B)
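
As a follow-up (my addition, not part of the original answer): newer TensorFlow releases ship a dedicated op for this. A minimal sketch assuming TensorFlow 1.13 or later, where tf.broadcast_to() exists and tf.select() has become tf.where():

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])

mask = tf.greater_equal(C, 0.0)             # boolean mask, shape (20, 100, 1)
mask = tf.broadcast_to(mask, tf.shape(A))   # explicit broadcast to (20, 100, 10)
D = tf.where(mask, A, B)

In TensorFlow 2.x, tf.where broadcasts its arguments itself, so the explicit step may not even be needed there.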
answered Oct 13 '22 by Blaine Rogers