See https://github.com/tensorflow/tensorflow/issues/32875
The fix suggested there was:
    class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
        @tf.function
        def __call__(self, y_true, y_pred, sample_weight=None):
            y_pred = tf.argmax(y_pred, axis=-1)  # this is the fix
            return super().__call__(y_true, y_pred, sample_weight=sample_weight)
It worked in TF 2.1 but broke again in TF 2.2. Is there a way to pass tf.argmax(y_pred, axis=-1) as y_pred to this metric other than subclassing?
This fixes the issue:
    class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
        def __init__(self,
                     y_true=None,
                     y_pred=None,
                     num_classes=None,
                     name=None,
                     dtype=None):
            super(UpdatedMeanIoU, self).__init__(num_classes=num_classes, name=name, dtype=dtype)

        def update_state(self, y_true, y_pred, sample_weight=None):
            y_pred = tf.math.argmax(y_pred, axis=-1)
            return super().update_state(y_true, y_pred, sample_weight)
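To see why the argmax step matters, here is a minimal plain-Python sketch of what the subclass does before delegating to the parent metric: it collapses each per-class score vector to a single class index, and mean IoU is then computed from integer labels. This is an illustration only, not TensorFlow's implementation; the helper names `argmax_last_axis` and `mean_iou` and the sample data are made up for the example.

```python
def argmax_last_axis(scores):
    """Return, for each score vector, the index of its largest entry."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def mean_iou(y_true, y_pred_labels, num_classes):
    """Mean intersection-over-union computed from integer class labels."""
    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred_labels):
        cm[t][p] += 1

    ious = []
    for c in range(num_classes):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(num_classes)) - tp
        fn = sum(cm[c]) - tp
        denom = tp + fp + fn
        if denom > 0:  # skip classes absent from both labels and predictions
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0


# Hypothetical data: integer labels and per-class scores (e.g. softmax output).
y_true = [0, 1, 2, 1]
scores = [
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.4, 0.3],
    [0.1, 0.6, 0.3],
]

labels = argmax_last_axis(scores)  # scores -> class indices
result = mean_iou(y_true, labels, num_classes=3)
print(labels, result)
```

Without the argmax step, the metric would be comparing integer labels against score vectors of a different shape, which is the mismatch the subclass works around.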