Is there a way (code or scripts) to merge TensorFlow batch norm and dropout layers into the convolution layer at inference time for faster computation?
I have searched for a while but have not found relevant answers.
Batch norm during inference: this is where the two moving-average parameters come in, the ones calculated during training and saved with the model. We use those saved mean and variance values for batch norm during inference.
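A minimal numpy sketch of how such moving averages are typically accumulated during training. It assumes the exponential-moving-average update with `momentum=0.99`, which is the default in `tf.layers.batch_normalization` and Keras `BatchNormalization`; exact variance handling differs slightly between frameworks, and `update_moving_stats` is a hypothetical helper, not a TensorFlow API.

```python
import numpy as np


def update_moving_stats(moving_mean, moving_var, batch, momentum=0.99):
    """One training-step update of the per-channel running statistics.

    `batch` has shape (N, C); statistics are taken over axis 0.
    momentum=0.99 mirrors the common TensorFlow default (assumption).
    """
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)  # biased variance; frameworks may differ slightly
    moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    moving_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return moving_mean, moving_var


rng = np.random.default_rng(0)
moving_mean, moving_var = np.zeros(3), np.ones(3)
for _ in range(2000):
    batch = rng.normal(loc=5.0, scale=2.0, size=(64, 3))
    moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)
```

After enough steps, `moving_mean` and `moving_var` settle near the data's true mean (5) and variance (4); these saved values are what inference uses instead of batch statistics.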
Batch Normalization is a technique which takes care of normalizing the input of each layer to make the training process faster and more stable. In practice, it is an extra layer that we generally add after the computation layer and before the non-linearity.
Batch normalization also adds two trainable parameters to each layer: the normalized output is multiplied by a “standard deviation” parameter (gamma) and shifted by a “mean” parameter (beta).
It means that during inference, the batch normalization acts as a simple linear transformation of what comes out of the previous layer, often a convolution. As a convolution is also a linear transformation, it also means that both operations can be merged into a single linear transformation!
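To make the "simple linear transformation" concrete, here is a small numpy sketch showing that inference-mode batch norm equals `a * x + c` with one `(a, c)` pair per channel. The `epsilon` of 1e-3 is an assumption (it matches the TensorFlow/Keras default), and both function names are illustrative only.

```python
import numpy as np


def batch_norm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    # Uses the saved moving statistics, not the statistics of the current batch.
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta


def as_affine(gamma, beta, moving_mean, moving_var, eps=1e-3):
    # The same operation written as y = a * x + c, one (a, c) pair per channel.
    a = gamma / np.sqrt(moving_var + eps)
    c = beta - a * moving_mean
    return a, c


rng = np.random.default_rng(1)
x = rng.normal(size=(8, 4))  # 8 samples, 4 channels
gamma, beta = rng.normal(size=4), rng.normal(size=4)
moving_mean, moving_var = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4)

a, c = as_affine(gamma, beta, moving_mean, moving_var)
bn_out = batch_norm_inference(x, gamma, beta, moving_mean, moving_var)
```

Because `a` and `c` are constants at inference time, they can be pushed into the preceding convolution's weights and bias.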
To the best of my knowledge, there is no built-in feature in TensorFlow for folding batch normalization. That being said, it's not that hard to do manually. One note: there is no such thing as folding dropout, since dropout is simply deactivated at inference time.
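Why there is nothing to fold for dropout: TensorFlow's dropout is "inverted" dropout, which rescales the kept units by `1 / (1 - rate)` during training, so at inference it is the identity. A minimal numpy sketch of that behaviour follows; `inverted_dropout` is a hypothetical helper, not the TensorFlow implementation.

```python
import numpy as np


def inverted_dropout(x, rate, training, rng):
    """Inverted dropout: scaling happens at training time, so inference is a no-op."""
    if not training or rate == 0.0:
        return x  # inference: identity, nothing left to fold
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)


rng = np.random.default_rng(2)
x = rng.normal(size=(4, 5))
y_infer = inverted_dropout(x, 0.5, training=False, rng=rng)
y_train = inverted_dropout(x, 0.5, training=True, rng=rng)
```

In training mode every output is either 0 or `2 * x` (for `rate=0.5`); in inference mode the input passes through unchanged.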
To fold batch normalization, there are basically three steps:
First, filter the variables that require folding. Batch normalization creates variables whose names contain `moving_mean` and `moving_variance`, so you can use this to extract the variables of the layers that used batch norm fairly easily.
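A minimal sketch of that filtering step in plain Python. It assumes TF 1.x-style variable names such as `model/conv1/moving_mean:0` (for example, the names from `[v.name for v in tf.global_variables()]`); the actual scope layout, e.g. an extra `/BatchNorm/` level, depends on how the model was built, and `find_batchnorm_layers` is a hypothetical helper.

```python
def strip_suffix(name, suffix):
    # 'model/conv1/moving_mean:0' minus 'moving_mean:0' --> 'model/conv1'
    return name[: -len(suffix)].rstrip('/')


def find_batchnorm_layers(variable_names):
    """Return the scope prefixes of layers that have both moving statistics."""
    means = {strip_suffix(n, 'moving_mean:0')
             for n in variable_names if n.endswith('moving_mean:0')}
    variances = {strip_suffix(n, 'moving_variance:0')
                 for n in variable_names if n.endswith('moving_variance:0')}
    return sorted(means & variances)


names = [
    'model/conv1/weights:0', 'model/conv1/moving_mean:0', 'model/conv1/moving_variance:0',
    'model/conv2/weights:0', 'model/conv2/bias:0',
]
print(find_batchnorm_layers(names))  # ['model/conv1']
```

Requiring both `moving_mean` and `moving_variance` guards against unrelated variables that happen to match only one suffix.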
Now that you know which layers used batch norm, for every such layer you can extract its weights `W`, bias `b`, batch norm variance `var`, mean `m`, `gamma` and `beta` parameters. You need to create new variables to store the folded weights and biases as follows, where `epsilon` is the small constant batch norm adds to the variance (0.001 by default in `tf.layers` and Keras):
W_new = gamma * W / sqrt(var + epsilon)
b_new = gamma * (b - m) / sqrt(var + epsilon) + beta
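The following numpy sketch checks the folding numerically for a convolution: it compares `batch_norm(conv(x, W) + b)` with `conv(x, W_new) + b_new`. Note that batch norm divides by `sqrt(var + epsilon)`, not by `var`, so that is what the fold uses. It assumes TensorFlow's NHWC input / HWIO kernel layout, stride 1 and `VALID` padding, and `epsilon = 1e-3`; `conv2d_valid` is a naive stand-in for a real convolution, not a TensorFlow API.

```python
import numpy as np

EPS = 1e-3  # assumption: must match the epsilon of the original batch norm layer


def conv2d_valid(x, W, b):
    """Naive NHWC / HWIO 'VALID' convolution with stride 1 (illustration only)."""
    n, h, w, _ = x.shape
    kh, kw, _, c_out = W.shape
    out = np.empty((n, h - kh + 1, w - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, kh, kw, C_in)
            out[:, i, j, :] = np.tensordot(patch, W, axes=([1, 2, 3], [0, 1, 2]))
    return out + b


def batch_norm(y, gamma, beta, mean, var, eps=EPS):
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


def fold(W, b, gamma, beta, mean, var, eps=EPS):
    scale = gamma / np.sqrt(var + eps)  # one factor per output channel
    # scale broadcasts over the last (output-channel) axis of W
    return W * scale, (b - mean) * scale + beta


rng = np.random.default_rng(3)
x = rng.normal(size=(2, 6, 6, 3))
W, b = rng.normal(size=(3, 3, 3, 4)), rng.normal(size=4)
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mean, var = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4)

reference = batch_norm(conv2d_valid(x, W, b), gamma, beta, mean, var)
W_new, b_new = fold(W, b, gamma, beta, mean, var)
folded = conv2d_valid(x, W_new, b_new)
```

Since convolution is linear in `W`, scaling each output channel's kernel is the same as scaling that channel's output, which is why a single per-channel factor suffices.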
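The last step consists of creating a new graph in which batch norm is deactivated and bias variables are added where necessary. That should be the case for every foldable layer, since using a bias together with batch norm is pointless (batch norm's `beta` already plays that role), so such layers are usually built without one.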
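A small numpy sketch of why a bias before batch norm is pointless during training: batch norm subtracts the batch mean, so a constant per-channel bias shifts that mean by exactly the same amount and cancels out (the variance is unchanged by a shift). `batch_norm_training` is an illustrative helper and `eps = 1e-3` is an assumption.

```python
import numpy as np


def batch_norm_training(y, gamma, beta, eps=1e-3):
    # Training mode: normalize with the statistics of the current batch.
    mean, var = y.mean(axis=0), y.var(axis=0)
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(4)
y = rng.normal(size=(32, 4))  # pre-bias layer outputs, 4 channels
bias = rng.normal(size=4)
gamma, beta = np.ones(4), np.zeros(4)

with_bias = batch_norm_training(y + bias, gamma, beta)
without_bias = batch_norm_training(y, gamma, beta)
```

The folded layer still needs a bias slot, though: after folding, it holds `b_new`, which absorbs the moving mean and `beta`.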
The whole code should look something like the following. Depending on the parameters used for the batch norm (e.g. `scale=False` or `center=False`), your graph may not have `gamma` or `beta`.
```python
import numpy as np
import tensorflow as tf  # TF 1.x graph/session API

EPSILON = 1e-3  # must match the epsilon used by your batch norm layers


def get_layer_name(var_name):
    # 'model/conv1/moving_variance:0' --> 'model/conv1'
    return var_name.rsplit('/', 1)[0]


# ****** (1) Get variables ******
# `session` is the session holding the trained model
variables = {v.name: session.run(v) for v in tf.global_variables()}

# ****** (2) Fold variables ******
folded_variables = {}
for v in variables.keys():
    if not v.endswith('moving_variance:0'):
        continue
    n = get_layer_name(v)
    W = variables[n + '/weights:0']  # or '/kernel:0', etc.
    b = variables.get(n + '/bias:0', 0.0)       # 0 if there was no bias before
    gamma = variables.get(n + '/gamma:0', 1.0)  # 1 if batch norm had no scale
    beta = variables.get(n + '/beta:0', 0.0)    # 0 if batch norm had no center
    m = variables[n + '/moving_mean:0']
    var = variables[n + '/moving_variance:0']
    # folding batch norm (scale broadcasts over the output-channel axis of W)
    scale = gamma / np.sqrt(var + EPSILON)
    folded_variables[n + '/weights:0'] = W * scale
    folded_variables[n + '/bias:0'] = (b - m) * scale + beta

# ****** (3) Create new graph ******
new_graph = tf.Graph()
with new_graph.as_default():
    network = ...  # instance batch-norm free graph with bias added.
    # Careful, the names should match the original model
    new_session = tf.Session(graph=new_graph)
    for v in tf.global_variables():
        if v.name in folded_variables:
            new_session.run(v.assign(folded_variables[v.name]))
        else:
            new_session.run(v.assign(variables[v.name]))
```
TensorFlow also provides a tool, the Graph Transform Tool, that optimizes a trained, frozen graph for inference, including folding batch norms: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#fold_batch_norms

1. Build the graph transform tool:

```shell
bazel build tensorflow/tools/graph_transforms:transform_graph
```

2. Freeze your graph, e.g. as described in https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

3. Run the transforms:

```shell
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul' \
--outputs='softmax' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms'
```