
TypeError: __init__() got an unexpected keyword argument 'trainable'

I am trying to load an RNN model architecture trained in Keras using keras.models.model_from_json, and I am getting the error in the title:

with open('model_architecture.json', 'r') as f:
    model = model_from_json(f.read(), custom_objects={'AttLayer':AttLayer})

# Load weights into the new model

Here is the custom layer I am using:

class AttLayer(Layer):
    def __init__(self, attention_dim):
        self.init = initializers.get('normal')
        self.supports_masking = True
        self.attention_dim = attention_dim
        super(AttLayer, self).__init__()

    def build(self, input_shape):
        assert len(input_shape) == 3
        self.W = K.variable(self.init((input_shape[-1], self.attention_dim)))
        self.b = K.variable(self.init((self.attention_dim, )))
        self.u = K.variable(self.init((self.attention_dim, 1)))
        self.trainable_weights = [self.W, self.b, self.u]
        super(AttLayer, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        return None

    def call(self, x, mask=None):
        # size of x :[batch_size, seq_len, attention_dim]
        # size of u :[batch_size, attention_dim]
        # uit = tanh(xW+b)
        uit = K.tanh(K.bias_add(K.dot(x, self.W), self.b))
        ait = K.dot(uit, self.u)
        ait = K.squeeze(ait, -1)

        ait = K.exp(ait)

        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            ait *= K.cast(mask, K.floatx())
        ait /= K.cast(K.sum(ait, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        ait = K.expand_dims(ait)
        weighted_input = x * ait
        output = K.sum(weighted_input, axis=1)

        return output

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        config = {'attention_dim': self.attention_dim}
        base_config = super(AttLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


File "scripts/Classifier.py", line 254, in test
    model = model_from_json(f.read(), custom_objects={'AttLayer':AttLayer})
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/models.py", line 345, in model_from_json
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/wrappers.py", line 100, in from_config
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 1254, in from_config
    return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'trainable'



I tried training and loading with different versions, but with no luck. Finally, I removed the 'trainable' and 'name' key-value pairs from my custom layer's entry in the model architecture file (model_architecture.json), and the model now seems to load without any error. But this looks like a workaround rather than a real fix, and I have to repeat it every time I train the model.
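For reference, the manual edit described above could be scripted roughly as follows. This is only a sketch: strip_generic_keys is a hypothetical helper, the sample architecture is hand-written for illustration, and the code assumes every layer entry in the JSON is a dict with 'class_name' and 'config' keys.

```python
import json


def strip_generic_keys(node, class_name='AttLayer', keys=('trainable', 'name')):
    """Recursively drop `keys` from the config of every `class_name` layer.

    Hypothetical helper; assumes layer entries look like
    {'class_name': ..., 'config': {...}}.
    """
    if isinstance(node, dict):
        if node.get('class_name') == class_name and isinstance(node.get('config'), dict):
            for key in keys:
                node['config'].pop(key, None)
        # Keep walking so nested models and wrappers are covered too.
        for value in node.values():
            strip_generic_keys(value, class_name, keys)
    elif isinstance(node, list):
        for item in node:
            strip_generic_keys(item, class_name, keys)
    return node


# Hand-written stand-in for the contents of model_architecture.json.
architecture_json = json.dumps({
    'class_name': 'Model',
    'config': {
        'layers': [
            {'class_name': 'Dense',
             'config': {'name': 'dense_1', 'trainable': True, 'units': 8}},
            {'class_name': 'AttLayer',
             'config': {'name': 'att_layer_1', 'trainable': True,
                        'attention_dim': 100}},
        ],
    },
})

cleaned = strip_generic_keys(json.loads(architecture_json))
att_config = cleaned['config']['layers'][1]['config']
dense_config = cleaned['config']['layers'][0]['config']
```

Only the AttLayer entry is touched; other layers keep their 'name' and 'trainable' keys.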

Biswadip Mandal asked Nov 01 '18 09:11


1 Answer

I think you missed a small detail in your layer definition. Your layer's __init__ method should accept keyword arguments (**kwargs) and pass them on to the parent class's __init__, like this:

class AttLayer(Layer):
    def __init__(self, attention_dim, **kwargs):
        self.init = initializers.get('normal')
        self.supports_masking = True
        self.attention_dim = attention_dim
        super(AttLayer, self).__init__(**kwargs)

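This way, any generic layer parameter (in your case, the trainable flag) is passed on to the parent class correctly. As the traceback shows, from_config calls cls(**config), and the config returned by your get_config includes the base layer's keys, so __init__ has to accept them.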
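To see the mechanism without Keras installed, here is a minimal sketch. The Layer class below is a simplified stand-in, not the real Keras base class; it only mimics the two pieces that matter here: get_config returning the generic keys and from_config calling cls(**config). BrokenAttLayer, FixedAttLayer and round_trip are illustrative names.

```python
class Layer(object):
    """Simplified stand-in for the Keras base layer (an assumption, not Keras)."""

    def __init__(self, name=None, trainable=True):
        self.name = name or self.__class__.__name__.lower()
        self.trainable = trainable

    def get_config(self):
        # The base config always carries the generic keys.
        return {'name': self.name, 'trainable': self.trainable}

    @classmethod
    def from_config(cls, config):
        # Every saved key is handed to __init__ as a keyword argument.
        return cls(**config)


class BrokenAttLayer(Layer):
    """Mirrors the question: __init__ accepts only attention_dim."""

    def __init__(self, attention_dim):
        self.attention_dim = attention_dim
        super(BrokenAttLayer, self).__init__()

    def get_config(self):
        config = {'attention_dim': self.attention_dim}
        base_config = super(BrokenAttLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class FixedAttLayer(Layer):
    """Mirrors the answer: extra keyword arguments go to the parent class."""

    def __init__(self, attention_dim, **kwargs):
        self.attention_dim = attention_dim
        super(FixedAttLayer, self).__init__(**kwargs)

    def get_config(self):
        config = {'attention_dim': self.attention_dim}
        base_config = super(FixedAttLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def round_trip(layer):
    """Serialize a layer to its config and rebuild it from that config."""
    return layer.__class__.from_config(layer.get_config())


# The broken layer cannot be rebuilt: its config holds keys __init__ rejects.
try:
    round_trip(BrokenAttLayer(100))
    broken_error = None
except TypeError as exc:
    broken_error = str(exc)

# The fixed layer survives the round trip with its generic settings intact.
restored = round_trip(FixedAttLayer(100, name='att', trainable=False))
```

The broken variant fails with the same kind of TypeError as in the question, while the fixed one comes back with attention_dim, name and trainable preserved.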

Dr. Snoopy answered Nov 14 '22 23:11
