
Using Keras for real-time training and predicting

I want to use Keras in a real-time training and prediction setting. In my scenario I receive real-time data via MQTT that should be used to train an (LSTM) neural network and/or be fed to it to get a prediction.

I am using the TensorFlow backend with GPU support and fairly powerful GPU capacity, but in my scenario Keras does not really benefit from GPU acceleration. (I ran some performance tests with the examples in the Keras repository to make sure that GPU acceleration works in general.) In my first approach, I used the model.train_on_batch(...) method to train the network on each item arriving via MQTT:

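model = load_model()  # loads the (LSTM) Keras model; helper not shown

def on_message(msg):
    """
    Callback invoked by the MQTT client each time new data comes in
    """

    if msg.topic == 'my/topic':
        X, Y = prepare_data(msg.payload)  # turn the payload into inputs/targets (not shown)

        prediction = model.predict(X)      # predict on the new item first...
        loss = model.train_on_batch(X, Y)  # ...then take one training step on it

        send_to_visualization_tool(prediction, loss)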

One training step in this setting takes about 200 ms. However, when I introduce a buffer, e.g. buffering 100 data points, the training time for the whole batch increases only slightly. This suggests that each batch-training call carries a large fixed setup overhead. I also noticed that with batches of size 1, CPU usage is quite high while the GPU is hardly used at all.
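The buffering experiment described above can be sketched as a small size-based buffer that lives inside the MQTT callback, so no extra thread or synchronized queue is needed. This is a minimal sketch under assumptions: `BatchBuffer` and `flush_fn` are hypothetical names, and `flush_fn` stands in for something like preparing the data and calling `model.train_on_batch` on the whole batch.

```python
from typing import Any, Callable, List, Optional


class BatchBuffer:
    """Hypothetical helper: collects items and passes them to flush_fn in batches."""

    def __init__(self, batch_size: int, flush_fn: Callable[[List[Any]], Any]) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.batch_size = batch_size
        self.flush_fn = flush_fn
        self._items: List[Any] = []

    def add(self, item: Any) -> Optional[Any]:
        """Store one item; when the buffer is full, flush it and return flush_fn's result."""
        self._items.append(item)
        if len(self._items) >= self.batch_size:
            return self.flush()
        return None  # still filling up

    def flush(self) -> Optional[Any]:
        """Hand all buffered items to flush_fn (e.g. on shutdown); no-op if empty."""
        if not self._items:
            return None
        batch, self._items = self._items, []  # reset the buffer before calling out
        return self.flush_fn(batch)


# Possible use inside the callback (hypothetical, following the question's names):
#   buffer = BatchBuffer(100, lambda batch: model.train_on_batch(*prepare_data(batch)))
#   def on_message(msg):
#       buffer.add(msg.payload)

if __name__ == "__main__":
    buf = BatchBuffer(3, len)  # flush_fn just reports the batch size here
    print([buf.add(i) for i in range(7)])  # [None, None, 3, None, None, 3, None]
    print(buf.flush())  # 1 (the leftover seventh item)
```

The trade-off is latency: training (and anything tied to it, such as the loss sent to the visualization tool) only happens once `batch_size` items have arrived, rather than on every message.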

As an alternative, I have now introduced a synchronized queue: the MQTT client pushes data onto it whenever data comes in, and the neural network then consumes, as one batch, all the data that arrived while it was processing the previous batch:

import queue  # this module is named Queue in Python 2

train_data_queue = queue.Queue()

# MQTT client running in a separate thread
def on_message(msg):
    train_data_queue.put(msg.payload)

model = load_model()

while True:
    train_data_batch = dequeue_all(train_data_queue)  # dequeue all items from the queue,
                                                      # or block until at least one
                                                      # item is present
    X, Y = prepare_data(train_data_batch)

    predictions = model.predict_on_batch(X)
    losses = model.train_on_batch(X, Y)  # training loss for the whole batch

    send_to_visualization_tool(predictions, losses)
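The `dequeue_all` helper used above is not shown. One possible implementation with Python 3's standard `queue` module is sketched below; the optional `timeout` parameter is an assumption, not part of the original. It blocks on `get()` until the first item arrives, then drains whatever else is already queued with `get_nowait()`, which never waits.

```python
import queue
from typing import Any, List, Optional


def dequeue_all(q: queue.Queue, timeout: Optional[float] = None) -> List[Any]:
    """Block until at least one item is available, then return every queued item.

    Raises queue.Empty if `timeout` (seconds) expires before the first item arrives.
    """
    items = [q.get(timeout=timeout)]  # blocking wait for the first item
    while True:
        try:
            items.append(q.get_nowait())  # grab anything else already waiting
        except queue.Empty:
            break  # queue drained; return the batch
    return items


if __name__ == "__main__":
    q: queue.Queue = queue.Queue()
    for payload in ["a", "b", "c"]:
        q.put(payload)
    print(dequeue_all(q))  # ['a', 'b', 'c']
```

With a single consumer loop, each call returns exactly the items that accumulated while the previous batch was being processed, so the batch size adapts automatically to the incoming data rate.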

This approach works okay, but it would be nice to get rid of the additional complexity of synchronized queues and multithreading, i.e. to get the first approach to work.

My question therefore is: is there a way to reduce the overhead of training on single-example batches, e.g. by reimplementing the model in pure TensorFlow? Or can you think of a better way to do real-time training with Keras?

asked Nov 09 '22 06:11 by Zwackelmann


1 Answer

The performance of Keras should be broadly similar to that of raw TensorFlow, so I do not recommend rewriting your model.

Indeed, modern hardware usually takes about the same time to train on a single example as on a batch of examples, which is why we spend so much effort batching things up. You can get rid of the complexity of synchronized queues by using tf.contrib.batching.batch_function (tf.contrib exists only in TensorFlow 1.x), but you will still need to feed it from many threads if you want the extra throughput.
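The batching argument can be made concrete with a simple cost model, which is an assumption rather than a measurement: suppose one training call on a batch of n examples costs a fixed per-call overhead c plus a per-example cost t.

```latex
T(n) \approx c + n\,t
\qquad\Longrightarrow\qquad
\frac{T(n)}{n} \approx \frac{c}{n} + t
```

When t is much smaller than c, which is what the question's observations suggest (about 200 ms for a single example and only slightly more for a batch of 100), c is roughly 200 ms and the cost per example at n = 100 falls to about c/100, i.e. roughly 2 ms, plus t.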

answered Nov 14 '22 22:11 by Alexandre Passos