
How to conditionally assign values to tensor [masking for loss function]?

I want to create an L2 loss function that ignores values (i.e. pixels) where the label is 0. The tensor batch[1] contains the labels, while output is the tensor holding the network output; both have shape (None, 300, 300, 1).

labels_mask = tf.identity(batch[1])
labels_mask[labels_mask > 0] = 1
loss = tf.reduce_sum(tf.square((output-batch[1])*labels_mask))/tf.reduce_sum(labels_mask)

My current code raises TypeError: 'Tensor' object does not support item assignment (on the second line). What's the TensorFlow way to do this? I also tried to normalize the loss with tf.reduce_sum(labels_mask), which I hope works as intended.

ScientiaEtVeritas asked Jan 29 '18 22:01



2 Answers

Here is an example of how to apply boolean indexing and conditionally assign values to a Variable:

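import tensorflow as tf  # written against the TensorFlow 1.x API

a = tf.Variable(initial_value=[0, 0, 4, 6, 1, 2, 4, 0])
mask = tf.greater_equal(a, 2)  # [False False  True  True False  True  True False]
indexes = tf.where(mask)  # [[2] [3] [5] [6]], shape=(4, 1)
# scatter_update expects integer indices, so pass indexes rather than the boolean mask
b = tf.scatter_update(a, indexes, tf.constant(1500))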

output:

[   0,    0, 1500, 1500,    1, 1500, 1500,    0]
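For a plain tensor rather than a Variable, a similar result can be sketched with the three-argument form of tf.where, which selects element-wise between two tensors and returns a new tensor instead of assigning in place. This is only a sketch, assuming TensorFlow 2.x with eager execution; the helper name replace_where is hypothetical and not part of TensorFlow.

```python
import tensorflow as tf

def replace_where(values, mask, new_value):
    """Return a new tensor equal to `values`, with `new_value` wherever `mask` is True.

    Hypothetical helper for illustration; it does not modify `values`.
    """
    # Build a tensor of the same shape and dtype filled with the replacement value.
    replacement = tf.fill(tf.shape(values), tf.cast(new_value, values.dtype))
    # Element-wise select: take `replacement` where mask is True, else keep `values`.
    return tf.where(mask, replacement, values)

a = tf.constant([0, 0, 4, 6, 1, 2, 4, 0])
b = replace_where(a, tf.greater_equal(a, 2), 1500)
```

Because nothing is assigned in place, this variant should also work on intermediate tensors such as network outputs, where a scatter update on a Variable is not available.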
Ivan Talalaev answered Oct 10 '22 23:10



If you wanted to write it that way, you would have to use TensorFlow's scatter methods for assignment. Unfortunately, TensorFlow doesn't really support boolean indexing either (the newer tf.boolean_mask makes it possible, but annoying). It would be tricky to write and difficult to read.

You have two options that are less annoying:

  1. Use labels_mask > 0 as a boolean mask with TensorFlow's recent tf.boolean_mask function. Maybe this is the more TensorFlow way, because it invokes arbitrarily specific functions.
  2. Cast labels_mask > 0 to float: tf.cast(labels_mask > 0, tf.float32). Then you can use it the way you wanted to in the final line of your code.
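Both options can be sketched as small loss functions. This is a minimal sketch, assuming TensorFlow 2.x with eager execution and float32 tensors of identical shape; the function names, the sample values, and the tf.maximum guard against an all-zero mask are illustrative additions, not part of the question's code.

```python
import tensorflow as tf

def masked_l2_loss(output, labels):
    """Option 2: cast the boolean mask to float and multiply it in."""
    # 1.0 where the label is positive, 0.0 elsewhere; no item assignment needed.
    mask = tf.cast(labels > 0, tf.float32)
    squared_error = tf.square((output - labels) * mask)
    # Assumption: avoid dividing by zero when no pixel in the batch is labelled.
    denom = tf.maximum(tf.reduce_sum(mask), 1.0)
    return tf.reduce_sum(squared_error) / denom

def masked_l2_loss_boolean_mask(output, labels):
    """Option 1: keep only the labelled pixels with tf.boolean_mask."""
    kept = tf.boolean_mask(output - labels, labels > 0)
    # Note: the mean of an empty selection is NaN, so this variant has no guard.
    return tf.reduce_mean(tf.square(kept))

# Tiny hand-made example: two of the four "pixels" have a non-zero label.
labels = tf.constant([[0.0, 2.0], [4.0, 0.0]])
output = tf.constant([[5.0, 1.0], [1.0, 7.0]])

loss_cast = masked_l2_loss(output, labels)               # (1 + 9) / 2 = 5.0
loss_bool = masked_l2_loss_boolean_mask(output, labels)  # mean of [1, 9] = 5.0
```

The cast version keeps every tensor at a fixed shape, which tends to be easier to reason about with a (None, 300, 300, 1) batch; the tf.boolean_mask version produces a flattened tensor whose length depends on the data.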
gngdb answered Oct 10 '22 23:10
