
Tensorflow 2.0: Accessing a batch's tensors from a callback

I'm using Tensorflow 2.0 and trying to write a tf.keras.callbacks.Callback that reads both the inputs and outputs of my model for the batch.

I expected to be able to override on_batch_end and access model.inputs and model.outputs, but they are not EagerTensors with values that I can read. Is there any way to access the actual tensor values that were involved in a batch?

This has many practical uses, such as writing these tensors to TensorBoard for debugging or serializing them for other purposes. I am aware that I could just run the whole model again using model.predict, but that would force me to run every input through the network twice (and I might also have a non-deterministic data generator). Any idea how to achieve this?

francoisr asked Jul 11 '19 11:07



1 Answer

No, a callback has no way to access the actual input and output values of a batch. That is simply not part of the design of callbacks: they only have access to the model, the arguments passed to fit, the epoch number, and some metric values. As you found, model.input and model.output only point to symbolic KerasTensors, not to actual values.

To do what you want, you could take the input, stack it (maybe using a RaggedTensor) with the output you care about, and make the result an extra output of your model. Then implement your functionality as a custom metric that only reads y_pred: inside the metric, unstack y_pred to recover the input and output, and then visualize or serialize them.
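A minimal sketch of the packing idea above, assuming dense rank-2 inputs so that a plain Concatenate layer is enough (no RaggedTensor). The names N_FEATURES, N_OUTPUTS, unpack, and inspect_batch, and the layer sizes, are hypothetical and chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 4  # hypothetical input width
N_OUTPUTS = 1   # hypothetical prediction width

inputs = tf.keras.Input(shape=(N_FEATURES,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
pred = tf.keras.layers.Dense(N_OUTPUTS, name="pred")(hidden)
# Extra output: the inputs and the prediction side by side, per sample.
packed = tf.keras.layers.Concatenate(name="packed")([inputs, pred])
model = tf.keras.Model(inputs, [pred, packed])


def unpack(y_pred):
    """Split a packed tensor back into (inputs, predictions)."""
    return y_pred[:, :N_FEATURES], y_pred[:, N_FEATURES:]


def inspect_batch(y_true, y_pred):
    """Metric-style function: reads only y_pred and returns a scalar."""
    x, y = unpack(y_pred)
    # tf.print also works in graph mode; replace with your own logging.
    tf.print("inputs:", x, "outputs:", y, summarize=3)
    return tf.reduce_mean(y)


# Eager check that the round trip recovers the batch.
batch = np.random.rand(3, N_FEATURES).astype("float32")
pred_out, packed_out = model(tf.constant(batch))
x_rec, y_rec = unpack(packed_out)
metric_value = inspect_batch(None, packed_out)
```

Attaching inspect_batch to the "packed" output in model.compile is left out here. Note that a metric normally runs in graph mode during fit, so calling .numpy() inside it would need run_eagerly=True or a tf.py_function wrapper.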

Another way might be to implement a custom Layer that uses tf.py_function to call back into a Python function. This will be very slow during serious training, but it may be good enough for diagnostics or debugging.
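The custom-layer idea could look roughly like the sketch below. TapLayer, _record, and the in-memory captured list are hypothetical names. The layer passes its input through unchanged and only uses tf.py_function for the side effect of handing the concrete batch to Python.

```python
import numpy as np
import tensorflow as tf

captured = []  # hypothetical sink; swap for TensorBoard or disk writes


def _record(batch):
    # Runs as ordinary Python with an EagerTensor, so .numpy() is available.
    captured.append(batch.numpy())


class TapLayer(tf.keras.layers.Layer):
    """Identity layer that hands every batch it sees to a Python function."""

    def call(self, inputs):
        # Tout=[] because _record returns nothing; the op is stateful,
        # so it is still executed when the layer runs inside a tf.function.
        tf.py_function(func=_record, inp=[inputs], Tout=[])
        return tf.identity(inputs)


# Eager usage: the layer returns its input and records a copy of it.
tap = TapLayer()
sample = tf.constant([[1.0, 2.0], [3.0, 4.0]])
passed_through = tap(sample)
```

Placing one TapLayer right after the Input and another after the final layer would capture both sides of a batch. Since tf.py_function cannot be compiled by XLA, a model containing this layer may need JIT compilation turned off, depending on the Keras version.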

Yaoshiang answered Oct 24 '22 04:10