I'm trying to split the optimizer's minimize function across two machines. On one machine I call compute_gradients, and on the other I call apply_gradients with the gradients that were sent over the network. The problem is that calling apply_gradients(...).run(feed_dict) doesn't seem to work no matter what I do. I've tried substituting placeholders for the gradient tensors passed to apply_gradients:
import tensorflow as tf
variables = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]
loss = -tf.reduce_sum(y_ * tf.log(y_conv))
optimizer = tf.train.AdamOptimizer(1e-4)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
compute_gradients = optimizer.compute_gradients(loss, variables)
placeholder_gradients = []
for grad_var in compute_gradients:
    placeholder_gradients.append((tf.placeholder('float', shape=grad_var[1].get_shape()), grad_var[1]))
apply_gradients = optimizer.apply_gradients(placeholder_gradients)
Then, later, when I receive the gradients, I call:
feed_dict = {}
for i, grad_var in enumerate(compute_gradients):
    feed_dict[placeholder_gradients[i][0]] = tf.convert_to_tensor(gradients[i])
apply_gradients.run(feed_dict=feed_dict)
However, when I do this, I get:
ValueError: setting an array element with a sequence.
This is only the latest thing I've tried. I've also tried the same approach without placeholders, as well as waiting to create the apply_gradients operation until I receive the gradients, which results in non-matching graph errors.
Which direction should I go with this?
Assuming that each gradients[i] is a NumPy array that you've fetched using some out-of-band mechanism, the fix is simply to remove the tf.convert_to_tensor() call when building feed_dict:
feed_dict = {}
for i, grad_var in enumerate(compute_gradients):
    feed_dict[placeholder_gradients[i][0]] = gradients[i]
apply_gradients.run(feed_dict=feed_dict)
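The fix assumes gradients arrives as a list of NumPy arrays in the same order as compute_gradients. As a minimal sketch of one possible out-of-band mechanism (an assumption; the question doesn't say how the data travels), the sending machine could evaluate the gradient tensors with sess.run, which returns NumPy arrays, and serialize them with np.savez, while the receiving machine restores them in order with np.load. The helper names pack_gradients and unpack_gradients are hypothetical, and the actual network transport (sockets, RPC, etc.) is left out.

```python
import io
from typing import List

import numpy as np


def pack_gradients(gradients: List[np.ndarray]) -> bytes:
    """Hypothetical helper: serialize a list of gradient arrays to bytes.

    np.savez stores positional arrays under the keys arr_0, arr_1, ...
    so the original order can be recovered on the receiving side.
    """
    buf = io.BytesIO()
    np.savez(buf, *gradients)
    return buf.getvalue()


def unpack_gradients(payload: bytes) -> List[np.ndarray]:
    """Hypothetical helper: restore the arrays written by pack_gradients."""
    with np.load(io.BytesIO(payload)) as npz:
        # Index by key explicitly so the order is numeric (arr_0, arr_1, ...,
        # arr_10) rather than whatever order npz.files happens to list.
        return [npz[f"arr_{i}"] for i in range(len(npz.files))]
```

On the receiving machine, gradients = unpack_gradients(payload) yields plain ndarrays that can go straight into feed_dict in the loop above, with no tf.convert_to_tensor() involved.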
Each value in a feed_dict should be a NumPy array (or some other object that is trivially convertible to a NumPy array). In particular, a tf.Tensor is not a valid value for a feed_dict.
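Because feed_dict values must be NumPy-convertible, it can help to coerce and shape-check each received value before feeding it. The sketch below uses a hypothetical helper, to_feed_value, that calls np.asarray and compares the result against an expected shape; in TF 1.x that shape could come from placeholder.get_shape().as_list() when the placeholder's shape is fully defined. It is a sanity check on received data, not part of the TensorFlow API.

```python
from typing import Any, Sequence

import numpy as np


def to_feed_value(value: Any, expected_shape: Sequence[int], dtype=np.float32) -> np.ndarray:
    """Hypothetical helper: turn a received value into a feedable ndarray.

    Converts with np.asarray (lists, tuples and ndarrays all work) and
    raises ValueError if the shape doesn't match the placeholder's shape.
    """
    arr = np.asarray(value, dtype=dtype)
    if arr.shape != tuple(expected_shape):
        raise ValueError(f"expected shape {tuple(expected_shape)}, got {arr.shape}")
    return arr
```

Used in the fixed loop, the assignment would become feed_dict[placeholder_gradients[i][0]] = to_feed_value(gradients[i], placeholder_gradients[i][0].get_shape().as_list()), so a mismatched or malformed payload fails with a clear message before apply_gradients runs.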