How to trigger a Python function inside a tf.keras custom loss function?

Inside my custom loss function I need to call a pure Python function, passing in the computed TD errors and some indexes. The function doesn't need to return anything or be differentiated. Here's the function I want to call:

def update_priorities(self, traces_idxs, td_errors):
    """Updates the priorities of the traces with specified indexes."""
    self.priorities[traces_idxs] = td_errors + eps  # eps: small constant defined elsewhere

I've tried using tf.py_function to call a wrapper function, but it only gets called if it's embedded in the graph, i.e. if it has inputs and outputs and those outputs are actually used. I therefore passed some of the tensors through the wrapper without performing any operations on them, and the function now gets called. Here's my entire custom loss function:

def masked_q_loss(data, y_pred):
    """Computes the MSE between the Q-values of the actions that were taken and the cumulative
    discounted rewards obtained after taking those actions. Updates trace priorities.
    """
    action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
    seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
    action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
    qvals = tf.gather_nd(y_pred, action_idxs)

    def update_priorities(_qvals, _target_qvals, _traces_idxs):
        """Computes the TD error and updates memory priorities."""
        td_error = _target_qvals - _qvals
        _traces_idxs = tf.cast(_traces_idxs, tf.int32)
        mem.update_priorities(_traces_idxs, td_error)
        return _qvals

    qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
    return tf.keras.losses.mse(qvals, target_qvals)

However, I get the following error due to the call to mem.update_priorities(_traces_idxs, td_error):

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

I don't need to compute gradients for update_priorities; I just want to call it at a specific point in the graph computation and forget about it. How can I do that?

asked Jun 08 '19 by Daniel Salvadori



1 Answer

Using .numpy() on the tensors inside the wrapper function fixed the problem:

import numpy as np

def update_priorities(_qvals, _target_qvals, _traces_idxs):
    """Computes the TD error and updates memory priorities."""
    # tf.py_function runs this eagerly, so the arguments arrive as
    # EagerTensors and .numpy() is available.
    td_error = np.abs((_target_qvals - _qvals).numpy())
    _traces_idxs = tf.cast(_traces_idxs, tf.int32).numpy()
    mem.update_priorities(_traces_idxs, td_error)
    return _qvals  # pass the tensor through so the op stays in the graph
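
This works because tf.py_function executes the wrapped function eagerly: the NumPy conversion happens entirely outside the graph, while the pass-through return value keeps the op wired into it. Below is a minimal, self-contained sketch of the same pattern that runs on its own. The Memory class, its size, the eps constant, and the example tensors are assumptions for illustration, not from the original post:

import numpy as np
import tensorflow as tf

class Memory:
    """Hypothetical stand-in for the replay memory in the question."""
    def __init__(self, size, eps=1e-6):
        self.priorities = np.zeros(size, dtype=np.float32)
        self.eps = eps  # small constant so no priority is exactly zero

    def update_priorities(self, traces_idxs, td_errors):
        """Updates the priorities of the traces with specified indexes."""
        self.priorities[traces_idxs] = td_errors + self.eps

mem = Memory(size=10)

def update_priorities(_qvals, _target_qvals, _traces_idxs):
    # Called eagerly by tf.py_function, so .numpy() works here even when
    # the surrounding loss is traced into a graph.
    td_error = np.abs((_target_qvals - _qvals).numpy())
    idxs = tf.cast(_traces_idxs, tf.int32).numpy()
    mem.update_priorities(idxs, td_error)
    return _qvals  # pass-through keeps the op embedded in the graph

qvals = tf.constant([1.0, 2.0])
target_qvals = tf.constant([1.5, 1.0])
traces_idxs = tf.constant([3.0, 7.0])

out = tf.py_function(func=update_priorities,
                     inp=[qvals, target_qvals, traces_idxs],
                     Tout=tf.float32)
print(mem.priorities)  # entries 3 and 7 now hold |TD error| + eps

With the fix in place, the loss can be passed at compile time in the usual way, e.g. model.compile(optimizer='adam', loss=masked_q_loss).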
answered Oct 23 '22 by Daniel Salvadori