
Loading weights in TH format when keras is set to TF format

I have Keras' image_dim_ordering property set to 'tf', so I define my models like this:

model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
model.add(Convolution2D(64, 3, 3, activation='relu'))

But when I call the load_weights method, it crashes because the weights were saved in "th" format:

Exception: Layer weight shape (3, 3, 3, 64) not compatible with provided weight shape (64, 3, 3, 3)

How can I load these weights and automatically transpose them to fit TensorFlow's format?

asked Sep 17 '16 by ldavid


1 Answer

I asked Francois Chollet about this (he doesn't have an SO account) and he kindly passed along this reply:


"th" format means that the convolutional kernels will have the shape (depth, input_depth, rows, cols)

"tf" format means that the convolutional kernels will have the shape (rows, cols, input_depth, depth)

Therefore you can convert from the former to the latter via np.transpose(x, (2, 3, 1, 0)), where x is the value of the convolution kernel.
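
For example, transposing a placeholder array with the shapes from the error message above confirms the permutation (the array contents are dummies; only the shape matters):

import numpy as np

th_kernel = np.zeros((64, 3, 3, 3))                # "th" layout: (depth, input_depth, rows, cols)
tf_kernel = np.transpose(th_kernel, (2, 3, 1, 0))
print(tf_kernel.shape)                             # (3, 3, 3, 64), i.e. the "tf" layout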

Here's some code to do the conversion:

import numpy as np
from keras import backend as K

K.set_image_dim_ordering('th')

# build model in TH mode, as th_model
th_model = ...
# load weights that were saved in TH mode into th_model
th_model.load_weights(...)

K.set_image_dim_ordering('tf')

# build model in TF mode, as tf_model
tf_model = ...

# transfer weights from th_model to tf_model, transposing conv kernels
for th_layer, tf_layer in zip(th_model.layers, tf_model.layers):
    if th_layer.__class__.__name__ == 'Convolution2D':
        kernel, bias = th_layer.get_weights()
        kernel = np.transpose(kernel, (2, 3, 1, 0))
        tf_layer.set_weights([kernel, bias])
    else:
        tf_layer.set_weights(th_layer.get_weights())

If the model contains Dense layers downstream of the Convolution2D layers, then the weight matrix of the first Dense layer would need to be shuffled as well.
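
Here's a minimal sketch of that shuffle, assuming the first Dense layer sits directly after a Flatten; th_dense, tf_dense and the spatial dimensions are placeholders rather than values from the original model:

rows, cols, depth = 7, 7, 512                # placeholder output shape of the last conv block (TF order)

w, b = th_dense.get_weights()                # w: (depth * rows * cols, units), flattened in TH order
units = w.shape[1]
w = w.reshape((depth, rows, cols, units))    # undo the TH-order flattening
w = np.transpose(w, (1, 2, 0, 3))            # reorder axes to (rows, cols, depth, units)
w = w.reshape((rows * cols * depth, units))  # re-flatten in TF order
tf_dense.set_weights([w, b])

Once the conv kernels (and, if needed, the first Dense layer) have been converted, tf_model.save_weights('weights_tf.h5') will persist them so the conversion only has to run once; the file name is just an example.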

answered Dec 04 '22 by Pete Warden