I have seen variations of this question asked, but I haven't found a satisfactory answer yet. Basically, I would like to do the equivalent of Keras's model.to_json(), model.get_weights(), model_from_json() and model.set_weights() in TensorFlow. I think I am getting close, but I am stuck at this point. I'd prefer to get the weights and the graph in the same string, but I understand if that isn't possible.
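For reference, the Keras workflow being described looks roughly like the sketch below. It assumes tf.keras; the helper name keras_round_trip is hypothetical. The architecture travels as a JSON string and the weights as a list of numpy arrays, and the two are recombined on the other side.

```python
import numpy as np
import tensorflow as tf


def keras_round_trip(model):
    # Architecture only, as a JSON string
    json_string = model.to_json()
    # Weight values, as a list of numpy arrays
    weights = model.get_weights()
    # Rebuild a freshly initialized model from the JSON, then copy the weights in
    clone = tf.keras.models.model_from_json(json_string)
    clone.set_weights(weights)
    return clone
```

Note that Keras itself keeps the two halves separate: to_json() never contains weight values.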
Currently, what I have is:
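import tensorflow as tf
from google.protobuf import json_format

# optimizer and loss_op come from the rest of the model (not shown)
g = optimizer.minimize(loss_op,
                       global_step=tf.train.get_global_step())
de = g.graph.as_graph_def()                  # graph structure only, no variable values
json_string = json_format.MessageToJson(de)  # GraphDef -> JSON string
gd = tf.GraphDef()
gd = json_format.Parse(json_string, gd)      # JSON string -> GraphDef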
That seems to recreate the graph fine, but obviously the meta graph information (variables, weights, etc.) is not included. There is also the MetaGraph, but the only function I see for it is export_meta_graph, which doesn't seem to serialize in the same way. I saw that MetaGraph has a proto function, but I don't know how to serialize the variables with it.
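One possible way to get the MetaGraph into a string without touching disk is sketched below, assuming TF1-style graph mode (accessed via tf.compat.v1 so it also runs on TF 2.x). Called without a filename, tf.train.export_meta_graph() just returns the MetaGraphDef proto, which json_format can serialize like the GraphDef in the question, and tf.train.import_meta_graph() accepts the proto directly. The helper names meta_graph_to_json and meta_graph_from_json are hypothetical.

```python
import tensorflow.compat.v1 as tf
from google.protobuf import json_format
from tensorflow.core.protobuf import meta_graph_pb2

tf.disable_v2_behavior()  # TF1-style graph mode, as in the question


def meta_graph_to_json(graph):
    # With no filename, export_meta_graph writes nothing to disk; it just returns the proto
    with graph.as_default():
        meta_graph_def = tf.train.export_meta_graph()
    return json_format.MessageToJson(meta_graph_def)


def meta_graph_from_json(json_string):
    # JSON string -> MetaGraphDef proto
    meta_graph_def = json_format.Parse(json_string, meta_graph_pb2.MetaGraphDef())
    graph = tf.Graph()
    with graph.as_default():
        # import_meta_graph accepts a MetaGraphDef proto as well as a filename
        tf.train.import_meta_graph(meta_graph_def)
    return graph
```

This restores the variable definitions and collections, but the restored variables are uninitialized: a MetaGraphDef does not carry the weight values, so those still have to be transferred separately.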
So, in short: how would you take a TensorFlow model (model as in weights, graph, etc.), serialize it to a string (preferably JSON), then deserialize it and continue training or serve predictions?
Here are the things I have tried that get me close, but most of them are limited by needing to write to disk, which I can't do in this case:
Gist on GitHub
This is the closest one I found, but its link to serializing a MetaGraph is dead.
Note that the solution from @Maxim creates new operations in the graph every time it is called.
If you call the function frequently, the graph keeps growing and your code gets slower and slower.
There are two ways to work around this problem:
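@Maxim's answer isn't reproduced here, so the sketch below only shows the general pattern being warned about: a setter that calls var.assign(value) on every call. It assumes TF1-style graph mode via tf.compat.v1, and naive_set is a hypothetical name. Counting the graph's operations before and after a few calls makes the growth visible.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    var = tf.Variable(tf.zeros([3]), name="var")
    init = tf.global_variables_initializer()


def naive_set(sess, value):
    # var.assign(value) builds a new Assign op (plus a Const for value) on every call
    sess.run(var.assign(value))


with graph.as_default(), tf.Session() as sess:
    sess.run(init)
    ops_before = len(graph.get_operations())
    for _ in range(5):
        naive_set(sess, np.ones(3, dtype=np.float32))
    ops_after = len(graph.get_operations())

print(ops_before, ops_after)  # ops_after keeps growing with every naive_set call
```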
Create the assign operations at the same time as the rest of the graph and reuse them:
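import tensorflow as tf

# Build one placeholder + assign op per trainable variable, once, with the rest of the graph
assign_placeholders = []
assign_ops = []
for var in tf.trainable_variables():
    assign_placeholder = tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
    assign_placeholders.append(assign_placeholder)
    assign_ops.append(var.assign(assign_placeholder))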
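To use assign ops that were built once, a set_weights along these lines only feeds new values into the existing placeholders, so repeated calls leave the graph size unchanged. This is a sketch assuming TF1-style graph mode (via tf.compat.v1 so it also runs on TF 2.x); the two-variable model and the set_weights/get_weights helper names are hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    # Hypothetical model: two trainable variables
    w = tf.Variable(tf.zeros([2, 3]), name="w")
    b = tf.Variable(tf.zeros([3]), name="b")
    params = tf.trainable_variables()
    # Built once, alongside the rest of the graph
    assign_placeholders = [tf.placeholder(v.dtype.base_dtype, shape=v.get_shape())
                           for v in params]
    assign_ops = [v.assign(p) for v, p in zip(params, assign_placeholders)]
    init = tf.global_variables_initializer()


def get_weights(sess):
    # Current values of all trainable variables, as numpy arrays
    return sess.run(params)


def set_weights(sess, weights):
    # Only feeds values into the existing placeholders; no new ops are created
    sess.run(assign_ops, feed_dict=dict(zip(assign_placeholders, weights)))


sess = tf.Session(graph=graph)
sess.run(init)
```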
Use the load method on the variables. I prefer this one because it removes the need for the code above:
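import numpy as np
import tensorflow as tf

class ModelWrapper:  # hypothetical class name; self.sess is the tf.Session in use
    def __init__(self, sess):
        self.sess = sess
        self.params = tf.trainable_variables()
    def get_weights(self):
        # Current values of all trainable variables, as numpy arrays
        return self.sess.run(self.params)
    def set_weights(self, weights):
        # Variable.load feeds the value through the variable's existing initializer op,
        # so no new ops are added to the graph
        for var, value in zip(self.params, weights):
            var.load(np.asarray(value), self.sess)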
(I can't comment, so I posted this as an answer instead.)
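Since the question asks for a JSON string, and get_weights() returns a list of numpy arrays (which json.dumps can't handle directly), one option is to base64-encode each array's raw bytes together with its dtype and shape. A minimal sketch; weights_to_json and weights_from_json are hypothetical helper names, not part of TensorFlow.

```python
import base64
import json

import numpy as np


def weights_to_json(weights):
    # Encode each array's raw bytes as base64 so the values round-trip exactly
    encoded = []
    for w in weights:
        arr = np.asarray(w)
        encoded.append({
            "dtype": arr.dtype.str,  # e.g. "<f4", keeps the byte order
            "shape": list(arr.shape),
            "data": base64.b64encode(arr.tobytes()).decode("ascii"),
        })
    return json.dumps(encoded)


def weights_from_json(json_string):
    # Inverse of weights_to_json: rebuild the list of numpy arrays
    arrays = []
    for item in json.loads(json_string):
        raw = base64.b64decode(item["data"])
        arr = np.frombuffer(raw, dtype=np.dtype(item["dtype"])).reshape(tuple(item["shape"]))
        arrays.append(arr.copy())  # frombuffer returns a read-only view; copy makes it writable
    return arrays
```

Usage would be weights_to_json(get_weights()) on the sending side and set_weights(weights_from_json(s)) on the receiving side. To keep graph and weights in one string, both JSON strings could be wrapped in a single object, e.g. json.dumps({"graph": json_string, "weights": ...}).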