I'm using TensorFlow Probability to train a model whose output is a tfp.distributions.Independent object, for probabilistic regression. My problem is that I'm unsure how to implement sample weighting in the negative log likelihood (NLL) loss function.
I have the following loss function, which I believe ignores the third argument, sample_weight:
```python
class NLL(tf.keras.losses.Loss):
    ''' Custom keras loss/metric for negative log likelihood '''
    def __call__(self, y_true, y_pred, sample_weight=None):
        return -y_pred.log_prob(y_true)
```
With standard TensorFlow loss functions and a dataset of (X, y, sample_weight) tuples, sample_weight is applied during the loss reduction under the hood. How can I make the sum in y_pred.log_prob use the weights in the sample_weight tensor?
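As a rough, framework-free illustration of that "under the hood" step, the sketch below mimics the default tf.keras reduction (SUM_OVER_BATCH_SIZE, assuming tf.keras 2.x semantics): each loss value is multiplied by its weight, the products are summed, and the sum is divided by the number of loss values. The helper names keras_style_weighted_loss and weighted_mean are hypothetical; the second is included only to show that dividing by the number of elements is not the same as normalizing by the total weight.

```python
def keras_style_weighted_loss(losses, sample_weight=None):
    """Mimic tf.keras's default SUM_OVER_BATCH_SIZE reduction (assumed
    tf.keras 2.x behaviour): weight, sum, divide by the element count."""
    if sample_weight is None:
        sample_weight = [1.0] * len(losses)  # no weights means weight 1.0
    if len(sample_weight) != len(losses):
        raise ValueError("sample_weight needs one entry per loss value")
    weighted = [w * l for w, l in zip(sample_weight, losses)]
    return sum(weighted) / len(weighted)


def weighted_mean(losses, sample_weight):
    """Hypothetical alternative: normalize by the sum of the weights."""
    total = sum(sample_weight)
    return sum(w * l for w, l in zip(sample_weight, losses)) / total


per_example_nll = [1.0, 2.0, 3.0, 4.0]
weights = [1.0, 0.0, 1.0, 0.0]

print(keras_style_weighted_loss(per_example_nll))           # 2.5
print(keras_style_weighted_loss(per_example_nll, weights))  # 1.0
print(weighted_mean(per_example_nll, weights))              # 2.0
```

With zero weights present, the Keras-style value is pulled towards zero because the zero-weighted examples still count in the denominator, whereas the weighted mean averages only over the examples that carry weight.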
I found a solution to my problem as posted in this GitHub issue.
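My problem was caused by the fact that my model outputs a tfp.distributions.Independent distribution. Independent sums the individual log_prob values over its reinterpreted (event) dimensions, so log_prob returns one value per example rather than one value per element of the output tensor. This prevents weighting individual elements of the loss. You can get the underlying tensor of element-wise log_prob values through the .distribution attribute of the Independent object: the underlying distribution treats each element as an independent random variable, rather than treating each example as a single random variable with multiple values.

If the loss subclasses tf.keras.losses.Loss and overrides call (not __call__), the base class's __call__ applies sample_weight to the returned tensor and then reduces it. A per-example weight of shape (batch,) is broadcast across the output dimensions, and with the default reduction (SUM_OVER_BATCH_SIZE) the result is the sum of the weighted values divided by the number of elements, which is not the same as a weighted mean normalized by the sum of the weights. Overriding __call__ directly, as in the question, bypasses this handling, which is why sample_weight was ignored. For example:

```python
import tensorflow as tf

class NLL(tf.keras.losses.Loss):
    ''' Custom keras loss/metric for weighted negative log likelihood '''
    def call(self, y_true, y_pred):
        # .distribution is the distribution wrapped by tfp Independent; its
        # log_prob is element-wise instead of summed over event dimensions.
        # Loss.__call__ then applies sample_weight (if any) and reduces the
        # result (by default: weighted sum / number of elements).
        return -y_pred.distribution.log_prob(y_true)
```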
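To make the shape difference concrete, here is a framework-free sketch in plain Python, assuming Normal distributions. It contrasts what the underlying distribution's log_prob returns (one value per element) with what an Independent wrapper using reinterpreted_batch_ndims=1 returns (each row summed into one value per example). The helper names and the element_weights values are hypothetical illustrations, not TFP API.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of a univariate Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def elementwise_log_prob(y_true, loc, scale):
    """Like the underlying distribution: one log_prob per element (batch x d)."""
    return [[normal_log_prob(y, m, s) for y, m, s in zip(yr, mr, sr)]
            for yr, mr, sr in zip(y_true, loc, scale)]


def independent_log_prob(y_true, loc, scale):
    """Like Independent(..., reinterpreted_batch_ndims=1): sum over the
    last (event) dimension, leaving one value per example."""
    return [sum(row) for row in elementwise_log_prob(y_true, loc, scale)]


y_true = [[0.0, 1.0], [2.0, 2.0]]
loc = [[0.0, 0.0], [2.0, 0.0]]
scale = [[1.0, 1.0], [1.0, 1.0]]

per_element = elementwise_log_prob(y_true, loc, scale)   # 2 x 2 values
per_example = independent_log_prob(y_true, loc, scale)   # 2 values

# Hypothetical per-element weights: down-weight the second output dimension.
# This is only possible before the per-example sum has been taken.
element_weights = [[1.0, 0.5], [1.0, 0.5]]
weighted_nll = [[-w * lp for w, lp in zip(wr, lr)]
                for wr, lr in zip(element_weights, per_element)]
```

Once Independent has summed each row, the per-dimension contributions can no longer be separated, which is why the answer reaches through .distribution before weighting.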