Inside my custom loss function I need to call a pure Python function, passing in the computed TD errors and some indices. The function doesn't need to return anything or be differentiated. Here's the function I want to call:
def update_priorities(self, traces_idxs, td_errors):
    """Updates the priorities of the traces with specified indexes."""
    self.priorities[traces_idxs] = td_errors + eps
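The method above assumes that `self.priorities` is a NumPy array with one entry per stored trace and that `eps` is a small positive constant that keeps every priority non-zero. A minimal sketch of such a container, where the class name `PrioritizedMemory`, the `capacity` argument and the value of `eps` are all hypothetical (none of them come from the question):

```python
import numpy as np

eps = 1e-6  # hypothetical small constant so that no trace ends up with priority 0


class PrioritizedMemory:
    """Hypothetical replay memory holding one priority per stored trace."""

    def __init__(self, capacity):
        self.priorities = np.zeros(capacity, dtype=np.float32)

    def update_priorities(self, traces_idxs, td_errors):
        """Updates the priorities of the traces with specified indexes."""
        # NumPy fancy indexing: each index in traces_idxs gets the matching TD error.
        self.priorities[traces_idxs] = td_errors + eps


mem = PrioritizedMemory(capacity=5)
mem.update_priorities(np.array([1, 3]), np.array([0.5, 2.0], dtype=np.float32))
print(mem.priorities)
```

Here only entries 1 and 3 change; every other trace keeps its previous priority.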
I've tried using tf.py_function to call a wrapper function, but it only gets called if it's embedded in the graph, i.e. if it has inputs and outputs and the outputs are used. Therefore I tried passing some of the tensors through without performing any operations on them, and now the function gets called. Here's my entire custom loss function:
def masked_q_loss(data, y_pred):
    """Computes the MSE between the Q-values of the actions that were taken and the cumulative
    discounted rewards obtained after taking those actions. Updates trace priorities.
    """
    action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
    seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
    action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
    qvals = tf.gather_nd(y_pred, action_idxs)

    def update_priorities(_qvals, _target_qvals, _traces_idxs):
        """Computes the TD error and updates memory priorities."""
        td_error = _target_qvals - _qvals
        _traces_idxs = tf.cast(_traces_idxs, tf.int32)
        mem.update_priorities(_traces_idxs, td_error)
        return _qvals

    qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
    return tf.keras.losses.mse(qvals, target_qvals)
However, I get the following error because of the call mem.update_priorities(_traces_idxs, td_error):
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
I don't need to compute gradients for update_priorities; I just want to call it at a specific point in the graph computation and forget about it. How can I do that?
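In `masked_q_loss`, the `tf.stack`/`tf.gather_nd` step pairs each row index with the action taken in that row, so `qvals[i] = y_pred[i, action_batch[i]]`, and the loss is the mean squared difference between those values and the targets. A NumPy sketch of the same arithmetic, useful for checking the indexing by hand; `masked_q_loss_np` is a hypothetical helper for illustration only and leaves out the priority update:

```python
import numpy as np


def masked_q_loss_np(data, y_pred):
    """NumPy mirror of masked_q_loss without the priority update (illustration only)."""
    action_batch, target_qvals = data[:, 0], data[:, 1]
    rows = np.arange(len(action_batch))
    # Same selection as tf.gather_nd(y_pred, stack([seq, actions]) transposed).
    qvals = y_pred[rows, action_batch.astype(int)]
    # tf.keras.losses.mse averages over the last axis; for 1-D inputs that is one scalar.
    loss = np.mean((target_qvals - qvals) ** 2)
    return loss, qvals


# Columns: action taken, target Q-value, trace index (as in the question's `data`).
data = np.array([[0, 1.0, 2], [1, 0.5, 5], [0, 2.0, 7]])
y_pred = np.array([[0.5, 0.0], [0.0, 1.0], [1.0, 3.0]])
loss, qvals = masked_q_loss_np(data, y_pred)
print(qvals, loss)
```

With these inputs the selected Q-values are 0.5, 1.0 and 1.0, and the loss is (0.25 + 0.25 + 1.0) / 3 = 0.5.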
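Calling .numpy() on the tensors inside the wrapper function, so that mem.update_priorities receives NumPy arrays instead of tensors, fixed the problem (the TD error is also made absolute with np.abs, so the priorities stay non-negative):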
def update_priorities(_qvals, _target_qvals, _traces_idxs):
    """Computes the TD error and updates memory priorities."""
    # Inside tf.py_function the arguments are eager tensors, so .numpy() is available;
    # np is NumPy (import numpy as np at module level).
    td_error = np.abs((_target_qvals - _qvals).numpy())
    _traces_idxs = (tf.cast(_traces_idxs, tf.int32)).numpy()
    mem.update_priorities(_traces_idxs, td_error)
    return _qvals
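A variant of the same idea, sketched under assumptions: `tf.numpy_function` (swapped in here for `tf.py_function`) hands the wrapper plain NumPy arrays, so no `.numpy()` calls are needed, and `tf.control_dependencies` ties the side effect to the loss without putting it on the gradient path. `PrioritizedMemory`, `EPS` and the capacity are hypothetical stand-ins for the question's `mem` and `eps`, and the MSE is written out with `tf.reduce_mean` so the sketch does not depend on a particular Keras version.

```python
import numpy as np
import tensorflow as tf

EPS = 1e-6  # hypothetical stand-in for the question's `eps`


class PrioritizedMemory:
    """Hypothetical stand-in for the question's `mem`."""

    def __init__(self, capacity):
        self.priorities = np.zeros(capacity, dtype=np.float32)

    def update_priorities(self, traces_idxs, td_errors):
        self.priorities[traces_idxs] = td_errors + EPS


mem = PrioritizedMemory(capacity=8)


def _update_priorities(qvals, target_qvals, traces_idxs):
    # tf.numpy_function passes NumPy arrays, so no .numpy() calls are needed.
    td_error = np.abs(target_qvals - qvals)
    mem.update_priorities(traces_idxs.astype(np.int64), td_error)
    return qvals  # dummy output; its value is never used in the loss


def masked_q_loss(data, y_pred):
    action_batch, target_qvals, traces_idxs = data[:, 0], data[:, 1], data[:, 2]
    seq = tf.range(tf.shape(action_batch)[0])
    action_idxs = tf.stack([seq, tf.cast(action_batch, tf.int32)], axis=1)
    qvals = tf.gather_nd(y_pred, action_idxs)
    # Side effect only: stop_gradient makes explicit that nothing is
    # differentiated through the NumPy call.
    update_op = tf.numpy_function(
        _update_priorities,
        [tf.stop_gradient(qvals), target_qvals, traces_idxs],
        tf.float32,
    )
    # In graph mode (e.g. inside tf.function), the ops created below only run
    # after update_op; in eager mode the update has already executed.
    with tf.control_dependencies([update_op]):
        return tf.reduce_mean(tf.square(target_qvals - qvals))
```

The gradient of the loss reaches `y_pred` only through `tf.gather_nd`, while the priority update runs every time the loss is computed; the function can be passed to `model.compile(loss=masked_q_loss, ...)` like any other Keras loss.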