I have trained a model using python-tensorflow and I would like to do the inference in java-tensorflow. I have loaded the trained model/graph into Java. After this, I would like to update one variable in the graph permanently. I am aware of the tf.Variable.load(value, session) function in Python that can be used to update the value of a variable. I was wondering if there is any similar method in Java.
I have tried the following things so far.
// g and s are the loaded Graph and Session, respectively
s.runner().feed(variableName, updatedTensorValue)
But the above line only substitutes updatedTensorValue for variableName during the run executed on that same runner (the fetch calls chained on the same line); the variable itself keeps its old value.
g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();
Instead of updating the value, the above line tries to add a new operation with the same name as the variable to the graph, and hence it throws an exception.
As an alternative to permanently updating the variable in the graph, I could always call feed(variableName, updatedTensorValue) in every fetch call. Since I will be running the inference code on several instances, I was wondering how much additional time this extra feed would take.
Thanks
The way to do most things in TensorFlow is to execute an operation. You were on the right track in trying to run the Assign operation, but you invoked it incorrectly: the value to be assigned is not an "attribute" of Assign but one of its two input tensors (the first input is the variable itself, the second is the new value). Also, the second argument of opBuilder is the name of the new operation, so it must not reuse the variable's name. (See the raw definition of the operation, though admittedly that definition may not be easy to follow unless you're familiar with TensorFlow internals.)
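For reference, here is a minimal sketch of invoking Assign with inputs instead of an attribute, assuming the TensorFlow Java 1.x API (org.tensorflow.Graph, Session, and the op builder returned by Graph.opBuilder). To keep it self-contained, it builds a stand-in scalar float variable named myvar (plus a myvar/read identity) rather than importing the Python graph; with the imported graph you would look the variable up with g.operation("myvar") instead. The class name and the op names myvar/new_value and myvar/manual_assign are hypothetical; the only requirement is that the new names don't clash with existing nodes.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

class AssignWithInputs {

  /** Assigns newValue to a stand-in variable via an explicitly built Assign op, then reads it. */
  static float assignAndRead(float newValue) {
    try (Graph g = new Graph()) {
      // Stand-in for the variable imported from Python; with a real graph use g.operation("myvar").
      Operation variable = g.opBuilder("VariableV2", "myvar")
          .setAttr("dtype", DataType.FLOAT)
          .setAttr("shape", Shape.scalar())
          .build();
      g.opBuilder("Identity", "myvar/read").addInput(variable.output(0)).build();

      // Hypothetical placeholder that carries the new value into the graph.
      Operation value = g.opBuilder("Placeholder", "myvar/new_value")
          .setAttr("dtype", DataType.FLOAT)
          .build();

      // Assign takes the variable reference and the new value as its two inputs, not as an attribute.
      // Its name must be new: reusing the variable's own name clashes with the existing node.
      g.opBuilder("Assign", "myvar/manual_assign")
          .addInput(variable.output(0))
          .addInput(value.output(0))
          .build();

      try (Session s = new Session(g);
          Tensor<Float> t = Tensors.create(newValue)) {
        // Run the Assign op once; the variable keeps the value for the lifetime of this session.
        s.runner().feed("myvar/new_value", t).addTarget("myvar/manual_assign").run();
        // A separate run, with no feed, reads the stored value.
        try (Tensor<?> result = s.runner().fetch("myvar/read").run().get(0)) {
          return result.floatValue();
        }
      }
    }
  }
}
```

Because the assigned value lives in the session's variable state, later runs on the same Session see it without any feed.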
However, you do not need to add an operation to the graph in Java to do this. Instead, you can do exactly what tf.Variable.load does in Python: run the variable's initializer operation (tf.Variable.initializer) while feeding the new value into its value input.
For example, consider the following graph built in Python:
import tensorflow as tf  # TensorFlow 1.x (graph-mode) API
var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()
# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables: ', init.name)
print('myvar.initializer: ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)
Now, we replicate the behavior of the Python var.load() in Java to assign a value of 3.0 to the variable using something like this:
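import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

// s is the Session created from the graph imported from /tmp/graph.pb
try (Tensor<Float> newValue = Tensors.create(3.0f)) {
    s.runner()
        .feed("myvar/initial_value", newValue) // myvar.initializer.inputs[1].name (":0" index is optional)
        .addTarget("myvar/Assign")             // myvar.initializer.name
        .run();
}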
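The pattern above can be wrapped in a small helper. This is a sketch, again assuming TensorFlow Java 1.x; the class and method names (VariableLoader, load, readFloat, constFloat, demo) are hypothetical. The op names are passed in rather than derived, because they depend on how the variable was created in Python; use the names printed by the script above. For a self-contained check, demo builds the same four ops that tf.Variable(1.0, name='myvar') produces in TF 1.x (myvar/initial_value, myvar, myvar/Assign, myvar/read) instead of importing /tmp/graph.pb.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

class VariableLoader {

  /**
   * Counterpart of Python's var.load(value, session) for a TF 1.x variable: run the
   * initializer op while feeding the new value in place of the initializer's value input.
   */
  static void load(Session s, String initializerOp, String initialValueOp, Tensor<?> value) {
    // Only a target is requested, so run() returns an empty list and nothing needs closing.
    s.runner().feed(initialValueOp, value).addTarget(initializerOp).run();
  }

  /** Fetches a scalar float, e.g. the "myvar/read" op. */
  static float readFloat(Session s, String op) {
    try (Tensor<?> t = s.runner().fetch(op).run().get(0)) {
      return t.floatValue();
    }
  }

  /** Adds a scalar float Const op; the graph keeps its own copy of the tensor. */
  static Operation constFloat(Graph g, String name, float v) {
    try (Tensor<Float> t = Tensors.create(v)) {
      return g.opBuilder("Const", name)
          .setAttr("dtype", DataType.FLOAT)
          .setAttr("value", t)
          .build();
    }
  }

  /** Returns {value after normal init, value after load(newValue)} on a stand-in graph. */
  static float[] demo(float newValue) {
    try (Graph g = new Graph()) {
      // The ops that tf.Variable(1.0, name='myvar') writes in TF 1.x.
      Operation initial = constFloat(g, "myvar/initial_value", 1.0f);
      Operation variable = g.opBuilder("VariableV2", "myvar")
          .setAttr("dtype", DataType.FLOAT)
          .setAttr("shape", Shape.scalar())
          .build();
      g.opBuilder("Assign", "myvar/Assign")
          .addInput(variable.output(0))
          .addInput(initial.output(0))
          .build();
      g.opBuilder("Identity", "myvar/read").addInput(variable.output(0)).build();

      try (Session s = new Session(g)) {
        s.runner().addTarget("myvar/Assign").run(); // ordinary initialization to 1.0
        float afterInit = readFloat(s, "myvar/read");
        try (Tensor<Float> t = Tensors.create(newValue)) {
          load(s, "myvar/Assign", "myvar/initial_value", t);
        }
        // Separate run with no feed: the loaded value persists in this session.
        float afterLoad = readFloat(s, "myvar/read");
        return new float[] {afterInit, afterLoad};
      }
    }
  }
}
```

Since the loaded value is stored in the session's variable state, it persists for later runs on the same Session, so the per-call feed from the question is not needed. It does not change the serialized graph, though: a new Session, or running the initializer again without a feed, starts from 1.0.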
Hope that helps.