When using distributed TensorFlow, we create a parameter server with tf.train.Server.join(). However, I can't find any way to shut the server down other than killing the process. The TensorFlow documentation for join() says:
Blocks until the server has shut down.
This method currently blocks forever.
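For reference, the parameter server process I am creating looks roughly like this (just a minimal sketch; the hosts and ports are placeholders):
import tensorflow as tf

# minimal parameter-server process (hosts/ports are placeholders)
cluster = tf.train.ClusterSpec({'ps': ['localhost:2222'],
                                'worker': ['localhost:2223']})
server = tf.train.Server(cluster, job_name='ps', task_index=0)
server.join()  # blocks forever; the only way I know to stop it is to kill the process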
This is quite troublesome because I would like to create many servers for computation and shut them down when everything finishes.
Are there any possible solutions for this?
Thanks
This page appears pretty often on Google, so I thought I would try to improve on Yaroslav's answer by providing what I hope is a clearer answer for those just getting into distributed TensorFlow.
import tensorflow as tf
import threading

def main(job_name, task):
    cluster = tf.train.ClusterSpec({
        'ps': ['localhost:22222', 'localhost:22223'],
        'worker': ['localhost:22224', 'localhost:22225', 'localhost:22226']
    })
    server = tf.train.Server(cluster, job_name=job_name, task_index=task)

    if job_name == 'ps':
        # create a shared queue on the parameter server which is visible on /job:ps/task:%d
        with tf.device('/job:ps/task:%d' % task):
            queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue%d' % task)

        # wait for the queue to be filled
        with tf.Session(server.target) as sess:
            for i in range(cluster.num_tasks('worker')):
                sess.run(queue.dequeue())
                print('ps:%d received "done" from worker:%d' % (task, i))
            print('ps:%d quitting' % task)

    elif job_name == 'worker':
        queues = []
        # create the shared "done" queues that live on the parameter servers (/job:ps/task:%d)
        for i in range(cluster.num_tasks('ps')):
            with tf.device('/job:ps/task:%d' % i):
                queues.append(tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue%d' % i))

        # fill the queues
        with tf.Session(server.target) as sess:
            for i in range(cluster.num_tasks('ps')):
                _, size = sess.run([queues[i].enqueue(task), queues[i].size()])
                print('Worker:%d sending "done" to ps:%d [elements=%d]' % (task, i, size))

if __name__ == '__main__':
    threads = [
        threading.Thread(target=main, args=('ps', 0)),
        threading.Thread(target=main, args=('ps', 1)),
        threading.Thread(target=main, args=('worker', 0)),
        threading.Thread(target=main, args=('worker', 1)),
        threading.Thread(target=main, args=('worker', 2)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
It's pretty simple to extend the "canonical" Distributed TensorFlow example by replacing the worker section of the code with this snippet:
# create a worker that does nothing
elif job_name == 'worker':
    with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d' % task, cluster=cluster)):
        global_step = tf.train.get_or_create_global_step()
        no_op = tf.no_op()

    done_ops = []
    # create the shared "done" queues that live on the parameter servers (/job:ps/task:%d)
    for i in range(cluster.num_tasks('ps')):
        with tf.device('/job:ps/task:%d' % i):
            done_queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue' + str(i))
            done_ops.append(done_queue.enqueue(task))

    hooks = [tf.train.StopAtStepHook(last_step=1),
             tf.train.FinalOpsHook([done_ops])]
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task == 0),
                                           hooks=hooks) as sess:
        sess.run([no_op])
Note that the MonitoredTrainingSession version seems to be much slower at connecting all of the workers together.
You can have parameter server processes die on demand by using session.run(dequeue_op) instead of server.join(), and having another process enqueue something onto that queue when you want the parameter server process to die.
So for k parameter server shards you could create k queues, each with a unique shared_name property, and try to dequeue from each of them. When you want to bring down the servers, you loop over all queues and enqueue a token onto each one. This causes session.run to unblock, and the Python process runs to the end and quits, bringing down the server.
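As a rough sketch of that pattern (the queue names and the NUM_PS value below are just illustrative), each parameter server shard blocks on a dequeue instead of server.join(), and a client enqueues one token per shard when it is time to shut everything down:
import tensorflow as tf

NUM_PS = 2  # number of parameter server shards (illustrative)

def run_ps(server, task):
    # on ps shard `task`: block on a dequeue instead of server.join()
    with tf.device('/job:ps/task:%d' % task):
        queue = tf.FIFOQueue(1, tf.int32, shared_name='shutdown_queue%d' % task)
    with tf.Session(server.target) as sess:
        sess.run(queue.dequeue())  # returns once a token arrives; the process then exits

def shutdown_all_ps(master_target):
    # on the client: enqueue one token onto each shard's queue
    queues = []
    for i in range(NUM_PS):
        with tf.device('/job:ps/task:%d' % i):
            queues.append(tf.FIFOQueue(1, tf.int32, shared_name='shutdown_queue%d' % i))
    with tf.Session(master_target) as sess:
        for q in queues:
            sess.run(q.enqueue(1))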
Below is a self-contained example with 2 shards taken from: https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e
(for a multi-worker/multi-shard example, see https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca)
import subprocess
import tensorflow as tf
import time
import sys

flags = tf.flags
flags.DEFINE_string("port1", "12222", "port of worker1")
flags.DEFINE_string("port2", "12223", "port of worker2")
flags.DEFINE_string("task", "", "internal use")
FLAGS = flags.FLAGS

# setup local cluster from flags
host = "127.0.0.1:"
cluster = {"worker": [host + FLAGS.port1, host + FLAGS.port2]}
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()

if __name__ == '__main__':
    if not FLAGS.task:  # start servers and run client
        # launch distributed service
        def runcmd(cmd):
            subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)

        runcmd("python %s --task=0" % (sys.argv[0]))
        runcmd("python %s --task=1" % (sys.argv[0]))
        time.sleep(1)

        # bring down distributed service
        sess = tf.Session("grpc://" + host + FLAGS.port1)
        queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
        queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
        with tf.device("/job:worker/task:0"):
            add_op0 = tf.add(tf.ones(()), tf.ones(()))
        with tf.device("/job:worker/task:1"):
            add_op1 = tf.add(tf.ones(()), tf.ones(()))

        print("Running computation on server 0")
        print(sess.run(add_op0))
        print("Running computation on server 1")
        print(sess.run(add_op1))

        print("Bringing down server 0")
        sess.run(queue0.enqueue(1))
        print("Bringing down server 1")
        sess.run(queue1.enqueue(1))
    else:  # Launch TensorFlow server
        server = tf.train.Server(clusterspec, config=None,
                                 job_name="worker",
                                 task_index=int(FLAGS.task))
        print("Starting server " + FLAGS.task)
        sess = tf.Session(server.target)
        queue = tf.FIFOQueue(1, tf.int32, shared_name="queue" + FLAGS.task)
        sess.run(queue.dequeue())
        print("Terminating server " + FLAGS.task)
There's currently no clean way to shut down a TensorFlow gRPC server. It is possible to shut down a gRPC server, but doing it safely requires additional memory management for all of the in-flight request and response buffers, which would require a lot of additional plumbing (of the worst kind: asynchronous shared memory management...) for a feature that nobody had requested—until now!
In practice you should be able to use the same tf.train.Server object for many different computations. If this doesn't work for your use case, please feel free to open a GitHub issue and tell us more about it.
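As an illustration of that reuse (a hypothetical sketch, not an official example), a single in-process server can stay alive while you run several independent graphs against it:
import tensorflow as tf

# one in-process server, reused for several unrelated computations
server = tf.train.Server.create_local_server()

for step in range(3):
    with tf.Graph().as_default():
        x = tf.constant(step, dtype=tf.float32)
        with tf.Session(server.target) as sess:
            print(sess.run(x * x))
# the server keeps running; it only goes away when the process exits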