
How to do point-wise categorical crossentropy loss in Keras?

I have a network that produces a 4D output tensor in which the values at each spatial position (~pixel) are to be interpreted as the class probabilities for that position. In other words, the output has shape (num_batches, height, width, num_classes). I have labels of the same shape, where the true class is one-hot encoded. I would like to calculate the categorical crossentropy loss from these.
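Written out, and assuming the per-position losses are averaged over the batch and both spatial dimensions, the loss being asked about is the following. Here $N, H, W, C$ stand for num_batches, height, width and num_classes, $y$ is the one-hot label tensor, $p$ the predicted probabilities, and $z$ (hypothetical notation) the network's raw per-position outputs before the softmax over the class axis:

```latex
\mathcal{L} = -\frac{1}{N H W} \sum_{n=1}^{N} \sum_{i=1}^{H} \sum_{j=1}^{W} \sum_{c=1}^{C} y_{n,i,j,c} \, \log p_{n,i,j,c},
\qquad
p_{n,i,j,c} = \frac{\exp(z_{n,i,j,c})}{\sum_{k=1}^{C} \exp(z_{n,i,j,k})}
```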

Problem #1: The K.softmax function expects a 2D tensor of shape (num_batches, num_classes).

Problem #2: I'm not sure how the losses from each position should be combined. Is it correct to reshape the tensor to (num_batches * height * width, num_classes) and then call K.categorical_crossentropy on that? Or should I instead call K.categorical_crossentropy on tensors of shape (num_batches, num_classes) height*width times (once per position) and average the results?
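A minimal NumPy sketch comparing the two options in Problem #2. NumPy stands in for the Keras backend here, and the function names are hypothetical. Because every position contributes the same number of samples (num_batches), the mean of the per-position means equals the mean over all reshaped rows:

```python
import numpy as np

def pixelwise_cce_reshaped(y_true, y_pred, eps=1e-7):
    """Option 1: reshape to (N*H*W, C), then average the per-row crossentropy."""
    num_classes = y_true.shape[-1]
    t = y_true.reshape(-1, num_classes)
    p = np.clip(y_pred.reshape(-1, num_classes), eps, 1.0)
    return float(np.mean(-np.sum(t * np.log(p), axis=-1)))

def pixelwise_cce_looped(y_true, y_pred, eps=1e-7):
    """Option 2: crossentropy on each (N, C) slice, one per position, then average."""
    _, height, width, _ = y_true.shape
    per_position = []
    for i in range(height):
        for j in range(width):
            t = y_true[:, i, j, :]
            p = np.clip(y_pred[:, i, j, :], eps, 1.0)
            per_position.append(np.mean(-np.sum(t * np.log(p), axis=-1)))
    return float(np.mean(per_position))

rng = np.random.default_rng(0)
N, H, W, C = 2, 3, 4, 5
logits = rng.normal(size=(N, H, W, C))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over classes
labels = rng.integers(0, C, size=(N, H, W))
y_true = np.eye(C)[labels]  # one-hot labels, shape (N, H, W, C)

print(np.isclose(pixelwise_cce_reshaped(y_true, y_pred),
                 pixelwise_cce_looped(y_true, y_pred)))  # True
```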

Alex I, asked Mar 26 '17 19:03

2 Answers

I found this issue, which confirms my intuition.

In short: the softmax will take 2D or 3D inputs. If the input is 3D, Keras assumes a shape of (samples, timedimension, numclasses) and applies the softmax over the last axis. For some reason, it doesn't do that for 4D tensors.

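Solution: reshape your output into a sequence of pixels:

```python
from keras.layers import Reshape
reshaped_output = Reshape((height*width, num_classes))(output_tensor)
```

Then apply your softmax:

```python
from keras.layers import Activation
new_output = Activation('softmax')(reshaped_output)
```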

Then either reshape your target tensors to (height*width, num_classes) as well, or reshape that last layer back into (height, width, num_classes), e.g. with Reshape((height, width, num_classes)).
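The reshape trick can be checked with a small NumPy sketch. NumPy's row-major reshape mirrors what the Reshape layer does (note that Keras' Reshape target shape excludes the batch axis, while the NumPy calls below include it); softmax_last_axis is a hypothetical helper, not a Keras function. Reshaping to a sequence of pixels, applying softmax over the last axis and reshaping back gives exactly the per-pixel softmax, provided the spatial dimensions are restored in the original (height, width) order:

```python
import numpy as np

def softmax_last_axis(x):
    """Numerically stable softmax over the last axis of an array of any rank."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
batch, height, width, num_classes = 2, 3, 4, 5
output_tensor = rng.normal(size=(batch, height, width, num_classes))

# Equivalent of Reshape((height*width, num_classes)): a sequence of pixels per sample.
reshaped_output = output_tensor.reshape(batch, height * width, num_classes)
new_output = softmax_last_axis(reshaped_output)

# Restore the spatial layout in the same (height, width) order.
restored = new_output.reshape(batch, height, width, num_classes)

print(np.allclose(restored, softmax_last_axis(output_tensor)))  # True
print(np.allclose(restored.sum(axis=-1), 1.0))                  # True

# Restoring as (width, height) yields a shape that no longer matches the targets.
wrong = new_output.reshape(batch, width, height, num_classes)
print(wrong.shape == output_tensor.shape)                       # False
```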

Otherwise, something I would try if I weren't on my phone right now is wrapping the softmax in TimeDistributed, i.e. TimeDistributed(Activation('softmax')). I have no idea whether that works, though; I'll try it later.

I hope this helps :-)

Nassim Ben, answered Sep 29 '22 11:09

Just flatten the output to a 2D tensor of shape (num_batches, height * width * num_classes). You can do this with the Flatten layer. Make sure your y is flattened the same way (usually calling y = y.reshape((num_batches, height * width * num_classes)) is enough).

For your second question, categorical crossentropy over the flattened vector equals the sum of the crossentropies at all width*height positions, which is width*height times their average. Since that factor is a constant, this is essentially the same as averaging the per-position categorical crossentropy (by the definition of categorical crossentropy).
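A small NumPy check of both points, with NumPy standing in for Keras and random, illustrative data. It assumes the predictions are already per-position probabilities (softmax applied over the class axis) and evaluates the crossentropy formula directly, so it does not reproduce any extra normalization Keras' own categorical_crossentropy may apply:

```python
import numpy as np

rng = np.random.default_rng(2)
num_batches, height, width, num_classes = 2, 3, 4, 5

# Per-position class probabilities: softmax already applied over the class axis.
logits = rng.normal(size=(num_batches, height, width, num_classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
labels = rng.integers(0, num_classes, size=(num_batches, height, width))
y = np.eye(num_classes)[labels]  # one-hot, same shape as probs

# Flatten predictions and targets identically (row-major order).
flat_probs = probs.reshape((num_batches, height * width * num_classes))
flat_y = y.reshape((num_batches, height * width * num_classes))

# Crossentropy formula applied to the flattened vectors: one value per sample.
flat_loss = -np.sum(flat_y * np.log(flat_probs), axis=-1)

# Per-position crossentropies on the unflattened arrays.
per_position = -np.sum(y * np.log(probs), axis=-1)  # (num_batches, height, width)

# Flattened loss == sum over positions == (height * width) * mean over positions.
print(np.allclose(flat_loss, per_position.sum(axis=(1, 2))))                    # True
print(np.allclose(flat_loss, height * width * per_position.mean(axis=(1, 2))))  # True
```

Because height * width is a constant, both losses have the same minimum; with plain SGD the summed version behaves like the averaged one with a learning rate height * width times larger.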

oscfri, answered Sep 29 '22 13:09
