Keras, Tensorflow: How to set breakpoint (debug) in custom layer when evaluating?

I just want to do some numerical validation inside the custom layer.

Suppose we have a very simple custom layer:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

And the main program:

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((1, 100, 1)))  # batch of one sample matching Input((100, 1))

If I set a breakpoint at the line m = x * x, the program pauses there when y = test_layer()(input) is executed, because call() is invoked while the computational graph is being built.

But when I feed real values with model.predict() and want to check whether the layer works properly inside, execution does not pause at the line m = x * x.

My question is:

  1. Is the call() method only invoked while the computational graph is being built (i.e., is it not called again when real values are fed in)?

  2. How can I debug (or where should I insert a breakpoint) inside a layer to see the values of its variables when real input is fed in?

asked Aug 03 '18 by Edon Ron

2 Answers

In TensorFlow 2, you can add breakpoints inside tf.keras models and layers, including when using the fit, evaluate, and predict methods. However, you must set model.run_eagerly = True after calling model.compile() for the tensor values to be available in the debugger at the breakpoint. For example,

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


class SimpleModel(Model):

    def __init__(self):
        super().__init__()
        self.dense0 = Dense(2)
        self.dense1 = Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        z = self.dense1(z)  # Breakpoint in IDE here. =====
        return z

x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)

model0 = SimpleModel()
y0 = model0.call(x)  # Values of z shown at breakpoint. =====

model1 = SimpleModel()
model1.run_eagerly = True
model1.compile(optimizer=Adam(), loss=BinaryCrossentropy())
y1 = model1.predict(x)  # Values of z *not* shown at breakpoint. =====

model2 = SimpleModel()
model2.compile(optimizer=Adam(), loss=BinaryCrossentropy())
model2.run_eagerly = True
y2 = model2.predict(x)  # Values of z shown at breakpoint. =====

Note: this was tested in TensorFlow 2.0.0-rc0.
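
Applying the same idea to the custom layer from the question, here is a minimal sketch, assuming the layer is ported to tf.keras; add_weight replaces the manual K.variable bookkeeping:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K


class TestLayer(keras.layers.Layer):

    def build(self, input_shape):
        # add_weight registers the trainable scalar with the layer.
        self.w = self.add_weight(name='w', shape=(), initializer='ones')
        super().build(input_shape)

    def call(self, x):
        m = x * x                # Breakpoint here fires during predict()
        n = self.w * K.sqrt(x)   # once run_eagerly is set to True.
        return m + n


inputs = keras.layers.Input((100, 1))
outputs = TestLayer()(inputs)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.run_eagerly = True             # Must come after compile().
model.predict(np.ones((1, 100, 1)))  # call() now runs eagerly, so the breakpoint is hit.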

answered by golmschenk


  1. Yes. The call() method is only used to build the computational graph.

  2. As for debugging, I prefer TFDBG, the recommended debugging tool for TensorFlow, although it does not provide breakpoint functionality.

For Keras, you can add these lines to your script to use TFDBG:

import keras.backend as K
from tensorflow.python import debug as tf_debug

# Wrap the Keras session with the TFDBG CLI wrapper before running the model.
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)
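
Alternatively, if a full breakpoint is not required, tensor values can be printed at prediction time from inside call() by inserting a print op into the graph with keras.backend.print_tensor. This is a minimal sketch; the message text and the layer body are illustrative:

import keras
import keras.backend as K

class test_layer(keras.layers.Layer):
    def call(self, x, **kwargs):
        m = x * x
        # print_tensor returns the same tensor, but prints its value every
        # time the node is evaluated, e.g. during model.predict().
        m = K.print_tensor(m, message='m = ')
        return m + K.sqrt(x)
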
answered by Nathan Explosion