 

Dimension mismatch error when using tf.metrics.MeanIoU() with SparseCategoricalCrossentropy loss in TensorFlow 2.2

See https://github.com/tensorflow/tensorflow/issues/32875.

The fix suggested there was:

import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    @tf.function
    def __call__(self, y_true, y_pred, sample_weight=None):
        # This is the fix: turn per-class scores into class ids.
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)
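A NumPy sketch of where the dimension mismatch comes from, assuming a segmentation model with a softmax output and sparse integer labels with a trailing channel axis (the sizes below are arbitrary, chosen only for illustration). tf.keras.metrics.MeanIoU reshapes y_true and y_pred to 1-D before building its confusion matrix, so both must hold the same number of elements:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, num_classes = 2, 4, 4, 3

# Sparse labels (what SparseCategoricalCrossentropy expects): one class id per pixel.
y_true = rng.integers(0, num_classes, size=(batch, height, width, 1))
# Model output: one score per class per pixel, e.g. from a softmax layer.
y_pred = rng.random((batch, height, width, num_classes))

# Flattened element by element, raw scores give num_classes times too many predictions.
print(y_true.size, y_pred.size)  # 32 96

# The argmax fix: keep only the index of the highest-scoring class.
y_pred_ids = np.argmax(y_pred, axis=-1)
print(y_pred_ids.shape, y_true.size == y_pred_ids.size)  # (2, 4, 4) True
```

After argmax the shapes still differ, (2, 4, 4) versus (2, 4, 4, 1), but that is harmless because the metric flattens both inputs before pairing them.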

It worked in TF 2.1 but broke again in TF 2.2. Is there a way, other than subclassing, to pass tf.argmax(y_pred, axis=-1) to this metric as y_pred?

Asked Mar 02 '23 at 11:03 by HuckleberryFinn


1 Answer

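Overriding update_state() instead of __call__() fixes the issue in TF 2.2, because the Keras compile/fit path there calls the metric's update_state() directly rather than __call__():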

import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
  def __init__(self, num_classes=None, name=None, dtype=None):
    super(UpdatedMeanIoU, self).__init__(num_classes=num_classes, name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Convert per-class scores (..., num_classes) to class ids (...) before
    # MeanIoU flattens y_true and y_pred and compares them element by element.
    y_pred = tf.math.argmax(y_pred, axis=-1)
    return super().update_state(y_true, y_pred, sample_weight)
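To see why update_state() is the natural hook, here is the same pattern as a NumPy-only sketch outside TensorFlow. ArgmaxMeanIoU is a hypothetical class, not a Keras API; it assumes integer labels in [0, num_classes) and per-class scores on the last axis, and it mirrors the idea of MeanIoU keeping a running confusion matrix that update_state() adds to and result() reads:

```python
import numpy as np


class ArgmaxMeanIoU:
    """Hypothetical NumPy stand-in for the UpdatedMeanIoU pattern (not TensorFlow's code)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset_states()

    def reset_states(self):
        # Running confusion matrix: rows = true class, columns = predicted class.
        self.total_cm = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)

    def update_state(self, y_true, y_pred):
        # The fix: turn per-class scores (..., num_classes) into class ids (...).
        y_pred = np.argmax(y_pred, axis=-1)
        # Flatten both, so labels shaped (..., 1) pair up with ids shaped (...).
        t = np.asarray(y_true, dtype=np.int64).reshape(-1)
        p = np.asarray(y_pred, dtype=np.int64).reshape(-1)
        n = self.num_classes
        # Count (true, predicted) pairs in this batch and add them to the running total.
        self.total_cm += np.bincount(n * t + p, minlength=n * n).reshape(n, n)

    def result(self):
        tp = np.diag(self.total_cm).astype(np.float64)
        # IoU per class = TP / (TP + FP + FN) = TP / (column sum + row sum - TP).
        denom = self.total_cm.sum(axis=0) + self.total_cm.sum(axis=1) - tp
        valid = denom > 0  # skip classes absent from both labels and predictions
        if not valid.any():
            return 0.0
        return float(np.mean(tp[valid] / denom[valid]))


m = ArgmaxMeanIoU(num_classes=2)
# Two batches of labels with a trailing channel axis, plus scores for 2 classes.
m.update_state([[[0], [0]]], [[[0.9, 0.1], [0.2, 0.8]]])
m.update_state([[[1], [1]]], [[[0.3, 0.7], [0.4, 0.6]]])
print(m.result())  # ≈ 0.5833 (IoU 0.5 for class 0, 2/3 for class 1)
```

Because the confusion matrix accumulates across batches, the result is the IoU over every pixel seen, not an average of per-batch IoU scores. Keras wraps a plain function metric in a running mean of per-batch values, so a function metric would not reproduce this, which is one reason to keep the stateful subclass.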

Answered Apr 07 '23 at 02:04 by HuckleberryFinn