 

When restoring from a checkpoint, how can I change the data type of the parameters?

I have a pre-trained TensorFlow checkpoint in which all parameters are of the float32 data type.

How can I load the checkpoint parameters as float16? Alternatively, is there a way to change the data types stored in a checkpoint?

Following is my code snippet, which tries to load a float32 checkpoint into a float16 graph; it fails with a type mismatch error.

import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))
    pass

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
asked Jun 12 '19 07:06 by dontloo



1 Answer

Looking a bit into how savers work, it seems you can customize how they are constructed through a builder object. For example, you could write a builder that loads all values as tf.float32 and then casts them to the actual type of each variable:

import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    # Collect (name, slice, dtype) for every tensor to restore; spec.dtype is
    # the dtype of the variable in the current graph (float16 here).
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    # Read every tensor as float32, the dtype it was saved with...
    restore_dtypes = [tf.float32 for _ in dtypes]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      # ...then cast each one to the dtype of the variable it is restored into.
      return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
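The same restore-then-cast step can be mimicked without TensorFlow. The pure-Python sketch below is only an illustration under stated assumptions: the "checkpoint" is a hypothetical dict of raw little-endian float32 bytes (real TensorFlow checkpoints use their own format), and `save_float32` and `restore_as` are hypothetical helpers. Reading the bytes as float32 plays the role of `restore_v2` with `tf.float32`, and packing/unpacking with the `struct` module's `'e'` (half-precision) format plays the role of `tf.cast` to float16, rounding each value to the nearest representable one.

```python
import struct

# Hypothetical stand-in for a checkpoint entry: raw little-endian float32 bytes.
# Real TensorFlow checkpoints use a different on-disk format.
def save_float32(values):
    return struct.pack(f"<{len(values)}f", *values)

def restore_as(raw, target_format):
    """Hypothetical helper: read raw float32 bytes, then 'cast' them.

    target_format is a struct code: 'f' for float32, 'e' for float16.
    Packing and unpacking in the target format rounds each value to the
    nearest representable number, mirroring a float32 -> float16 conversion.
    """
    count = len(raw) // 4                      # 4 bytes per float32
    stored = struct.unpack(f"<{count}f", raw)  # step 1: restore as float32
    cast = struct.pack(f"<{count}{target_format}", *stored)  # step 2: cast
    return list(struct.unpack(f"<{count}{target_format}", cast))

checkpoint = {"foo:0": save_float32([0.50589913, -0.11597633, 1.0897961])}
restored = restore_as(checkpoint["foo:0"], "e")
print(restored)
```

Passing `'f'` instead of `'e'` returns the values at float32 precision, which corresponds to restoring without any cast.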

Note that the builder above assumes all tensors stored in the checkpoint are tf.float32. If necessary, you can adapt it to your use case, e.g. by passing the source type (or types) to its constructor. With this, you just need to pass the builder to the second saver to get your example to work:

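import os
import tensorflow as tf

# CastFromFloat32SaverBuilder is the class defined above.
# tf.train.Saver.save fails if the checkpoint's parent directory is missing.
os.makedirs("ckpt", exist_ok=True)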

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))

Output:

Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
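The restored values agree with the saved ones to roughly three significant digits, which is about what float16 allows. Keep in mind that the second graph computes `dense` entirely in float16, so its output is not simply the saved output cast element by element. As a rough check on the scale of the error, the sketch below rounds the first saved row to half precision with Python's `struct` module (`'e'` format, available since Python 3.6). `to_float16` is a hypothetical helper, and the 2**-11 bound only holds for normal (non-subnormal) values like these.

```python
import struct

def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# First row of the float32 output printed before saving (copied from above).
saved_row = [0.50589913, 0.33701038, -0.11597633]

# float16 keeps 11 significant bits, so rounding a normal number to nearest
# has a relative error of at most 2**-11 (about 4.9e-4).
unit_roundoff = 2.0 ** -11

rel_errors = []
for x in saved_row:
    h = to_float16(x)
    rel_errors.append(abs(h - x) / abs(x))
    print(f"{x:.8f} -> {h:.8f} (relative error {rel_errors[-1]:.1e})")
```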
answered Sep 18 '22 23:09 by jdehesa
