
Cannot save model using model.save after multi_gpu_model in Keras

Following the upgrade to Keras 2.0.9, I have been using the multi_gpu_model utility, but I can't save my models or best weights using model.save().

The error I get is

TypeError: can’t pickle module objects

I suspect there is some problem gaining access to the model object. Is there a workaround for this issue?

GhostRider asked Nov 09 '17 20:11


2 Answers

To be honest, the easiest approach is to examine the multi-GPU parallel model using

 parallel_model.summary()

(The parallel model is simply the model returned by multi_gpu_model.) The summary clearly shows the actual model as one of the layers (I think the penultimate one; I am not at my computer right now). You can then use the name of this layer to retrieve the model and save it.

 model = parallel_model.get_layer('sequential_1')

Often it's called sequential_1, but if you are using a published architecture, it may be 'googlenet' or 'alexnet'. You will see the name of the layer in the summary.

Then it's simple to save that model with model.save().

Maxim's approach works, but I think it's overkill.

Remark: you will need to compile both the model and the parallel model.

GhostRider answered Oct 07 '22 05:10



Here's a patched version that doesn't fail while saving:

from keras.layers import Lambda, concatenate
from keras import Model
import tensorflow as tf

def multi_gpu_model(model, gpus):
  # Accept either a list of GPU ids or a GPU count.
  if isinstance(gpus, (list, tuple)):
    num_gpus = len(gpus)
    target_gpu_ids = gpus
  else:
    num_gpus = gpus
    target_gpu_ids = range(num_gpus)

  def get_slice(data, i, parts):
    # Return the i-th of `parts` slices of the batch;
    # the last slice also takes any remainder.
    shape = tf.shape(data)
    batch_size = shape[:1]
    input_shape = shape[1:]
    step = batch_size // parts
    if i == num_gpus - 1:
      size = batch_size - step * i
    else:
      size = step
    size = tf.concat([size, input_shape], axis=0)
    stride = tf.concat([step, input_shape * 0], axis=0)
    start = stride * i
    return tf.slice(data, start, size)

  # One list of per-replica outputs for each model output.
  all_outputs = []
  for i in range(len(model.outputs)):
    all_outputs.append([])

  # Place a copy of the model on each GPU,
  # each getting a slice of the inputs.
  for i, gpu_id in enumerate(target_gpu_ids):
    with tf.device('/gpu:%d' % gpu_id):
      with tf.name_scope('replica_%d' % gpu_id):
        inputs = []
        # Retrieve a slice of the input.
        for x in model.inputs:
          input_shape = tuple(x.get_shape().as_list())[1:]
          slice_i = Lambda(get_slice,
                           output_shape=input_shape,
                           arguments={'i': i,
                                      'parts': num_gpus})(x)
          inputs.append(slice_i)

        # Apply model on slice
        # (creating a model replica on the target device).
        outputs = model(inputs)
        if not isinstance(outputs, list):
          outputs = [outputs]

        # Save the outputs for merging back together later.
        for o in range(len(outputs)):
          all_outputs[o].append(outputs[o])

  # Merge outputs on CPU.
  with tf.device('/cpu:0'):
    merged = []
    for name, outputs in zip(model.output_names, all_outputs):
      merged.append(concatenate(outputs,
                                axis=0, name=name))
    return Model(model.inputs, merged)

You can use this multi_gpu_model function until the bug is fixed in Keras. Also, when loading the model, it's important to provide the tensorflow module object:

model = load_model('multi_gpu_model.h5', {'tf': tf})

How it works

The problem is the import tensorflow line at the top of the body of multi_gpu_model:

def multi_gpu_model(model, gpus):
  import tensorflow as tf

This makes get_slice a closure that captures the number of GPUs (that's OK) and the tensorflow module (not OK). Saving the model tries to serialize all layers, including the Lambda layers that call get_slice, and fails precisely because tf is in the closure.

The solution is to move the import out of multi_gpu_model so that tf becomes a global object; get_slice still needs it in order to work. This fixes saving, but when loading one has to provide tf explicitly.
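The closure effect can be reproduced without Keras or TensorFlow. The following is only a minimal sketch under stated assumptions: the standard math module stands in for tensorflow, pickle stands in for whatever serialization Keras performs when saving, and the helper names (make_local_import, make_global_import, closure_contents, can_pickle) are hypothetical and exist only for this illustration.

```python
import math
import pickle


def make_local_import():
    # Import inside the enclosing function, as in the unpatched multi_gpu_model.
    import math as m  # stand-in for "import tensorflow as tf" (assumption)

    def get_value(x):
        # "m" is a free variable, so the module is stored in the closure.
        return m.sqrt(x)

    return get_value


def make_global_import():
    def get_value(x):
        # "math" is resolved as a module-level global, so nothing is captured.
        return math.sqrt(x)

    return get_value


def closure_contents(fn):
    """Return the objects captured in the closure cells of fn."""
    return tuple(cell.cell_contents for cell in (fn.__closure__ or ()))


def can_pickle(obj):
    """Return True if obj can be pickled, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False
```

Calling can_pickle(closure_contents(make_local_import())) returns False because the captured module cannot be pickled, while the same call on make_global_import() returns True because its closure is empty. That is the same difference the patched function relies on.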

Maxim answered Oct 07 '22 06:10
