I have a pre-trained TensorFlow checkpoint where the parameters are all of float32 data type.
How can I load checkpoint parameters as float16? Or is there a way to modify the data types of a checkpoint?
The following is my code snippet, which tries to load a float32 checkpoint into a float16 graph; it fails with a type mismatch error.
import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)
with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))
# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
Looking a bit into how savers work, it seems you can redefine their construction through a builder object. For example, you could have a builder that loads values as tf.float32 and then casts them to the actual type of each variable:
import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
    # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        from tensorflow.python.ops import io_ops
        restore_specs = []
        for saveable in saveables:
            for spec in saveable.specs:
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        names, slices, dtypes = zip(*restore_specs)
        # Read every tensor from the checkpoint as float32...
        restore_dtypes = [tf.float32 for _ in dtypes]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
            # ...then cast each one to the dtype of its variable in the graph.
            return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
Note this assumes that every tensor stored in the checkpoint is tf.float32. You can adapt the builder appropriately for your use case if necessary, e.g. by passing the source type or types in the constructor, as sketched below.
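Just as an illustration, and not part of the original solution (the class name CastSaverBuilder and its from_dtype constructor argument are made up here, not any TensorFlow API), such a parameterized builder might look like:

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastSaverBuilder(BaseSaverBuilder):
    # Hypothetical generalization of the builder above: the checkpoint's
    # on-disk dtype is passed in instead of being hard-coded to float32.
    def __init__(self, from_dtype):
        super(CastSaverBuilder, self).__init__()
        self._from_dtype = from_dtype

    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        from tensorflow.python.ops import io_ops
        restore_specs = [(spec.name, spec.slice_spec, spec.dtype)
                         for saveable in saveables
                         for spec in saveable.specs]
        names, slices, dtypes = zip(*restore_specs)
        # Read everything as the dtype stored on disk...
        restore_dtypes = [self._from_dtype for _ in dtypes]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(filename_tensor, names, slices,
                                         restore_dtypes)
            # ...then cast to each variable's in-graph dtype.
            return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

It would be used as, e.g., tf.train.Saver(assign, builder=CastSaverBuilder(tf.float32)). With this, you just need to use the casting builder in the second saver to get your example to work: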
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))
Output:
Value to save:
[[ 0.50589913 0.33701038 -0.11597633]
[ 0.27372625 0.27724823 0.49825498]
[ 1.0897961 -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506 0.337 -0.11597]
[ 0.2737 0.2773 0.4983 ]
[ 1.09 -0.296 -0.9175 ]]
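Note that the restored values show the expected float16 rounding. As for the other part of the question, whether the checkpoint itself can be converted: one option is to rewrite it offline, reading each tensor, casting it, and saving a new checkpoint. Here is a minimal sketch of that idea, assuming a TF1 name-based checkpoint at "tmp.ckpt" as in the example above; the output path tmp_fp16.ckpt and the float32-only filter are assumptions to adapt:

import numpy as np
import tensorflow as tf

# Sketch: rewrite an existing float32 checkpoint as float16 on disk.
reader = tf.train.NewCheckpointReader("tmp.ckpt")
with tf.Graph().as_default():
    converted = {}
    for name in reader.get_variable_to_shape_map():
        value = reader.get_tensor(name)
        # Cast only float32 tensors; leave e.g. integer counters untouched.
        if value.dtype == np.float32:
            value = value.astype(np.float16)
        # Keep the original checkpoint key as the save name; ':' is not
        # allowed in a variable name, so sanitize it for the graph variable.
        converted[name] = tf.Variable(value, name=name.replace(':', '_'))
    saver = tf.train.Saver(converted)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, "tmp_fp16.ckpt")

After this, the converted checkpoint can be restored into the float16 graph with a plain tf.train.Saver, no custom builder needed.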