 

TensorFlow: how to check if a tensor row is only zeroes?

I'm training a simple network to predict the bounding box coordinates of a single object. However, some pictures contain no object at all. Since the network always makes a prediction, it also predicts a confidence value between 0 and 1 that should indicate the probability that there is an object in the picture. My prediction tensor is called logits and it's a (batch_size, 5) tensor (confidence, x, y, width and height). Similarly, the labels tensor is also (batch_size, 5).

Previously I was training only with images that always had an object, so I could basically do

loss = tf.nn.l2_loss(logits - labels)
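For reference, the L2 loss function lives at tf.nn.l2_loss, and it computes half the sum of squares of every element of its input. The sketch below checks that on a small made-up batch; the tensor values are hypothetical and only illustrate the (confidence, x, y, width, height) layout.

```python
import tensorflow as tf

# Hypothetical batch of 2 images; columns: (confidence, x, y, width, height)
logits = tf.constant([[0.9, 0.40, 0.50, 0.20, 0.30],
                      [0.2, 0.10, 0.10, 0.05, 0.05]])
labels = tf.constant([[1.0, 0.45, 0.50, 0.25, 0.30],
                      [1.0, 0.10, 0.20, 0.05, 0.05]])

# tf.nn.l2_loss(t) returns sum(t ** 2) / 2 over all elements of t
loss = tf.nn.l2_loss(logits - labels)

# The same quantity written out explicitly
manual = 0.5 * tf.reduce_sum(tf.square(logits - labels))
print(float(loss), float(manual))
```

Every element contributes equally here, which is why the coordinate error of an empty image would also be penalized if such images were added to this loss unchanged.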

I want to start training also with pictures with no objects and when there is no object in the picture, I don't want the network to be penalized for whichever coordinates it predicted. In this case all that matters is the confidence value, which should be close to 0 (no object).

How should I structure my labels and loss function to accomplish this? I can set the label of images with no objects to all zeroes, but how do I check whether a particular row is only zeroes? And in that case, the corresponding row in the logits also needs to be set to zeroes (except the confidence value!) so that the loss incurred by the coordinates is also zero.
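Written as a formula, the loss described above could look like the following sketch. It assumes label row i is (c_i, b_i) with b_i = (x_i, y_i, w_i, h_i), prediction row i is (ĉ_i, b̂_i), and o_i is a hypothetical indicator that equals 1 when label row i is not all zeros and 0 otherwise. With o_i = 1 for every row it reduces to tf.nn.l2_loss(logits - labels).

```latex
L = \frac{1}{2} \sum_{i=1}^{N} \left[ (\hat{c}_i - c_i)^2 + o_i \,\lVert \hat{b}_i - b_i \rVert_2^2 \right],
\qquad
o_i = \begin{cases} 1 & \text{if label row } i \text{ has a nonzero entry} \\ 0 & \text{otherwise} \end{cases}
```

Zeroing the predicted coordinates of an all-zero label row gives the same result, because then both b̂_i and b_i are zero and the coordinate term vanishes.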

Clash asked Sep 01 '25 02:09



1 Answer

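You can use tf.math.count_nonzero() to check whether a tensor contains only zeros. Passing axis=1 counts the nonzero entries of each row separately, so any row whose count is 0 is all zeros. You can check the guide here.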
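A minimal sketch of how this could be wired into the loss, assuming both tensors are (batch_size, 5) with columns (confidence, x, y, width, height) and that empty images have an all-zero label row. The function name masked_box_loss is hypothetical. It counts nonzeros per label row, builds a 0/1 mask that always keeps the confidence column but keeps the four coordinate columns only for rows with an object, and multiplies the logits by that mask before applying tf.nn.l2_loss.

```python
import tensorflow as tf

def masked_box_loss(logits, labels):
    """Hypothetical helper: L2 loss that ignores coordinates for empty images.

    Assumes logits and labels are float tensors of shape (batch_size, 5),
    columns (confidence, x, y, width, height), and that images without
    an object have an all-zero label row.
    """
    # Nonzero entries per label row (int64); a count of 0 means "no object"
    nonzero_per_row = tf.math.count_nonzero(labels, axis=1)
    has_object = tf.cast(nonzero_per_row > 0, logits.dtype)  # (batch_size,)

    # Column mask: confidence (column 0) is always kept; the four
    # coordinate columns are kept only where the image has an object
    mask = tf.concat(
        [tf.ones_like(has_object)[:, None],
         tf.tile(has_object[:, None], [1, 4])],
        axis=1)  # (batch_size, 5)

    # Zero the predicted coordinates of empty images; their labels are
    # already zero, so those entries add nothing to the loss
    masked_logits = logits * mask
    return tf.nn.l2_loss(masked_logits - labels)


logits = tf.constant([[0.9, 0.40, 0.50, 0.20, 0.30],
                      [0.3, 0.70, 0.20, 0.10, 0.60]])
labels = tf.constant([[1.0, 0.45, 0.50, 0.25, 0.30],
                      [0.0, 0.00, 0.00, 0.00, 0.00]])  # second image: no object
print(float(masked_box_loss(logits, labels)))
```

In this example only the confidence error of the second row (0.3) is penalized; its predicted coordinates can be anything. If images with an object always have a confidence label of 1, the label's confidence column could serve as the row mask directly; counting nonzeros simply follows the approach suggested in this answer.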

vijay m answered Sep 02 '25 17:09
