When operating in graph mode in TF1, I believe I needed to wire up training=True and training=False via feed_dicts when I was using the functional-style API.
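That wiring looked roughly like the sketch below (reconstructed from memory, so the placeholder and op names are just illustrative):

import tensorflow as tf  # TF1.x, graph mode

# A boolean placeholder toggled per-step via the feed_dict:
is_training = tf.placeholder_with_default(False, shape=(), name="is_training")
images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="input_image")
labels = tf.placeholder(tf.int64, shape=(None,))

hid = tf.layers.dropout(tf.layers.flatten(images), rate=0.1, training=is_training)
logits = tf.layers.dense(hid, 10)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
train_op = tf.train.AdamOptimizer().minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Each step had to set the flag explicitly (x_batch/y_batch come from your data):
# sess.run(train_op, feed_dict={images: x_batch, labels: y_batch, is_training: True})
# sess.run(loss, feed_dict={images: x_batch, labels: y_batch, is_training: False})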
What is the proper way to do this in TF2? I believe this is automatically handled when using tf.keras.Sequential. For example, I don't need to specify training in the following example from the docs:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
Can I also assume that Keras automagically handles this when training with the functional API? Here is the same model, rewritten using the functional API:
inputs = tf.keras.Input(shape=(28, 28, 1), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
                             kernel_regularizer=tf.keras.regularizers.l2(0.02))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)
# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
I'm unsure whether hid = tf.keras.layers.BatchNormalization()(hid) needs to be hid = tf.keras.layers.BatchNormalization()(hid, training=training) instead.
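For concreteness, the explicit wiring I have in mind would look something like this (just a sketch of what I mean, not necessarily the right usage):

# Force training-mode behavior (normalize with batch statistics); passing
# training=False would instead force inference-mode behavior (moving statistics):
hid = tf.keras.layers.BatchNormalization()(hid, training=True)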
A Colab notebook for these models can be found here.
I realized that there is a bug in the BatchNormalization documentation [1], where the {{TRAINABLE_ATTRIBUTE_NOTE}} placeholder isn't actually replaced with the intended note [2]:
About setting layer.trainable = False on a BatchNormalization layer:

The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run.

Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.
Note that:

- This behavior only occurs as of TensorFlow 2.0. In 1.*, setting layer.trainable = False would freeze the layer but would not switch it to inference mode.
- Setting trainable on a model containing other layers will recursively set the trainable value of all inner layers.
- If the value of the trainable attribute is changed after calling compile() on a model, the new value doesn't take effect for this model until compile() is called again.

[1] https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization?version=stable
[2] https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/layers/normalization_v2.py#L26-L65
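To make the distinction concrete: in TF2, fit() calls layers with training=True and evaluate()/predict() call them with training=False, for Sequential and functional models alike, so the BatchNormalization line in the question needs no explicit argument. Below is a minimal sketch illustrating the different ways the flag gets set (the model and the layer name "bn" are made up for illustration):

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28, 1))
hid = tf.keras.layers.Flatten()(inputs)
hid = tf.keras.layers.Dropout(0.5)(hid)
hid = tf.keras.layers.BatchNormalization(name="bn")(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = np.random.rand(4, 28, 28, 1).astype("float32")

# Calling the model directly lets you choose the mode yourself:
y_train = model(x, training=True)   # dropout active, BN uses batch statistics
y_infer = model(x, training=False)  # dropout off, BN uses moving statistics

# fit()/evaluate()/predict() set the flag for you, which is why neither the
# Sequential nor the functional model in the question passes training anywhere.

# Separately, freezing BN ("frozen state" in the note above) also forces
# inference mode in TF2, per the quoted documentation:
model.get_layer("bn").trainable = False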