
How can I compute element-wise conditionals on batches in TensorFlow?

Tags:

tensorflow

I have a batch of neuron activations from a layer in a tensor A of shape [batch_size, layer_size]. Let B = tf.square(A). For each element e of each vector in A, I want to compute the following conditional: if abs(e) < 1: e ← 0 else e ← B(e), where B(e) is the element of B at the same position as e. Can I vectorize the entire operation, for example with a single tf.cond operation?

Lenar Hoyt asked Jun 19 '16 21:06



1 Answer

You may want to look at tf.where(condition, x, y), which selects element-wise from x where condition is true and from y otherwise.

For your issue:

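import tensorflow as tf

# TF 1.x graph-mode input; batch_size and layer_size are assumed to be defined earlier
A = tf.placeholder(tf.float32, [batch_size, layer_size])
B = tf.square(A)

# Boolean mask, True wherever |a| < 1
condition = tf.less(tf.abs(A), 1.)

# Take 0 where the mask is True, otherwise the squared activation
res = tf.where(condition, tf.zeros_like(B), B)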
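
To sanity-check the selection logic on concrete numbers, here is a minimal sketch using NumPy, whose np.where has the same three-argument, element-wise select semantics. The function name zero_small_square_rest and the sample batch are made up for illustration and are not part of the question or of TensorFlow.

```python
import numpy as np

def zero_small_square_rest(a):
    """Return 0 where |a| < 1 and a**2 elsewhere, element-wise."""
    a = np.asarray(a, dtype=np.float32)
    b = np.square(a)
    # Boolean mask with the same shape as the input
    condition = np.abs(a) < 1.0
    # Pick from the zeros where the mask is True, from b otherwise
    return np.where(condition, np.zeros_like(b), b)

# Hypothetical batch: 2 examples, 3 activations each
batch = np.array([[0.5, -2.0, 1.0],
                  [-0.9, 3.0, 0.0]], dtype=np.float32)
result = zero_small_square_rest(batch)
print(result)
```

Note that an activation of exactly 1.0 does not satisfy abs(e) < 1, so it is squared rather than zeroed. tf.cond is not a fit here because it expects a single scalar predicate and picks one whole branch, rather than selecting per element.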
Olivier Moindrot answered Nov 10 '22 03:11
