Running TensorFlow on multicore devices

I have a basic Android TensorFlowInferenceInterface example that runs fine in a single thread.

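import android.content.Context;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class InferenceExample {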

    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "intput_node0";
    private static final String OUTPUT_NODE = "output_node0"; 
    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {

        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];

        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);

        return res;
    }
}

The example crashes with various exceptions, including java.lang.ArrayIndexOutOfBoundsException and java.lang.NullPointerException, when it runs in a thread pool as in the example below, so I assume it is not thread safe.

final InferenceExample inference = new InferenceExample(context);

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);
Collection<Future<?>> futures = new LinkedList<Future<?>>();

for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            // Every task uses the same InferenceExample instance
            inference.run(randomData);
        }
    });
    futures.add(result);
}

for (Future<?> future : futures) {
    try {
        future.get();
    } catch (ExecutionException | InterruptedException e) {
        // Log the exception itself; e.getMessage() can be null
        Log.e("TF", "Inference task failed", e);
    }
}
executor.shutdown();
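The crashes are consistent with shared mutable state: inferenceInterface is static, so every InferenceExample uses the same TensorFlowInferenceInterface, and feed(), run() and fetch() are separate calls that pass data through that object's internal state. Two pool threads can therefore interleave, with one task fetching while another is still feeding. One way to avoid sharing, sketched here under stated assumptions, is to give each pool thread its own instance through ThreadLocal. FakeInference is a hypothetical stand-in that doubles its input, so the sketch runs without Android or TensorFlow; PerThreadInference and runAll are illustrative names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical stand-in for TensorFlowInferenceInterface: feed(), run() and
// fetch() communicate through instance fields, so one instance must never be
// used by two threads at the same time.
class FakeInference {
    private float[] input;
    private float[] output;

    void feed(float[] data) { input = data.clone(); }

    void run() {
        output = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            output[i] = input[i] * 2f; // toy "model": doubles every value
        }
    }

    void fetch(float[] dst) { System.arraycopy(output, 0, dst, 0, dst.length); }
}

class PerThreadInference {
    // Each worker thread lazily creates its own FakeInference on first use and
    // reuses it afterwards, so no locking is needed around feed/run/fetch.
    private static final ThreadLocal<FakeInference> LOCAL = new ThreadLocal<FakeInference>() {
        @Override
        protected FakeInference initialValue() {
            return new FakeInference();
        }
    };

    static float[] infer(float[] data) {
        FakeInference model = LOCAL.get();
        float[] res = new float[data.length];
        model.feed(data);
        model.run();
        model.fetch(res);
        return res;
    }

    // Submits `tasks` inferences to a pool of `threads` workers and returns
    // the results in submission order.
    static List<float[]> runAll(int threads, int tasks) {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            List<Future<float[]>> futures = new ArrayList<Future<float[]>>();
            for (int i = 0; i < tasks; i++) {
                final float[] data = {i, i + 1};
                futures.add(executor.submit(new Callable<float[]>() {
                    public float[] call() {
                        return infer(data);
                    }
                }));
            }
            List<float[]> results = new ArrayList<float[]>();
            for (Future<float[]> future : futures) {
                results.add(future.get());
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            executor.shutdown();
        }
    }
}
```

With the real class, each thread would load its own copy of model.pb, so a pool of N threads holds N copies of the model in memory; that is the main cost of this approach.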

Is it possible to leverage multicore Android devices with TensorFlowInferenceInterface?

Asked Oct 21 '17 17:10 by Chris Seymour



1 Answer

To make InferenceExample thread safe, I changed the TensorFlowInferenceInterface field from static to an instance field and made the run method synchronized:

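// Instance field instead of static: each InferenceExample owns its own TensorFlowInferenceInterface
private TensorFlowInferenceInterface inferenceInterface;

public InferenceExample(final Context context) {
    inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
}

// synchronized: the feed/run/fetch sequence on this instance runs as one unit
public synchronized float[] run(float[] data) { ... }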
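Making run synchronized works because the lock is held across the whole feed, run, fetch sequence, so a second caller on the same instance waits until the first has fetched its result; synchronizing feed, run and fetch individually would still allow interleaving. The sketch below shares one instance across a thread pool to check that every task gets its own result back. ToyModel is a hypothetical stand-in for TensorFlowInferenceInterface (it doubles its input), and SynchronizedInference mirrors the answer's InferenceExample under that assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical stand-in for TensorFlowInferenceInterface with stateful
// feed/run/fetch calls (toy model: output = 2 * input).
class ToyModel {
    private float[] input;
    private float[] output;

    void feed(float[] data) { input = data.clone(); }

    void run() {
        output = new float[input.length];
        for (int i = 0; i < input.length; i++) output[i] = 2f * input[i];
    }

    void fetch(float[] dst) { System.arraycopy(output, 0, dst, 0, dst.length); }
}

class SynchronizedInference {
    private final ToyModel model = new ToyModel(); // instance field, not static

    // The lock is held for the whole feed -> run -> fetch sequence, so
    // concurrent callers on this instance are serialized, not interleaved.
    synchronized float[] run(float[] data) {
        float[] res = new float[data.length];
        model.feed(data);
        model.run();
        model.fetch(res);
        return res;
    }

    // Shares one instance across `threads` workers; returns results in task order.
    static List<float[]> runShared(final SynchronizedInference shared, int threads, int tasks) {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<float[]>> jobs = new ArrayList<Callable<float[]>>();
            for (int i = 0; i < tasks; i++) {
                final float[] data = {i, -i};
                jobs.add(new Callable<float[]>() {
                    public float[] call() {
                        return shared.run(data);
                    }
                });
            }
            List<float[]> results = new ArrayList<float[]>();
            for (Future<float[]> f : executor.invokeAll(jobs)) {
                results.add(f.get());
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            executor.shutdown();
        }
    }
}
```

The cost is that one instance now serves one request at a time, which is why the answer spreads work over several instances.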

Then I distribute the tasks round-robin across a list of numThreads InferenceExample instances:

for (int i = 1; i <= 100; i++) {
    final int id = i % numThreads;
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            list.get(id).run(data);
        }
    });
    futures.add(result);
}
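Assigning instances by i % numThreads ignores which instances are actually free, so a task can wait on a busy instance while another instance sits idle, particularly when the executor has more threads than there are instances. A common alternative, swapped in here rather than taken from the answer, is to keep the instances in a BlockingQueue and have each task borrow whichever one is free and return it when done. PooledModel is a hypothetical stand-in for one InferenceExample (its toy model adds 1 to each value); ModelPool and runAll are illustrative names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical stand-in for one InferenceExample (toy model: output = input + 1).
// A real instance must only be used by one thread at a time.
class PooledModel {
    float[] run(float[] data) {
        float[] res = new float[data.length];
        for (int i = 0; i < data.length; i++) res[i] = data[i] + 1f;
        return res;
    }
}

class ModelPool {
    private final BlockingQueue<PooledModel> idle;

    ModelPool(int size) {
        idle = new ArrayBlockingQueue<PooledModel>(size);
        for (int i = 0; i < size; i++) idle.add(new PooledModel());
    }

    // Borrows a free instance (blocking while all are busy) and always returns it.
    float[] run(float[] data) throws InterruptedException {
        PooledModel model = idle.take();
        try {
            return model.run(data);
        } finally {
            idle.add(model); // capacity == pool size, so this never fails
        }
    }

    // Runs `tasks` inferences on `threads` workers; results in task order.
    static List<float[]> runAll(final ModelPool pool, int threads, int tasks) {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<float[]>> jobs = new ArrayList<Callable<float[]>>();
            for (int i = 0; i < tasks; i++) {
                final float[] data = {i};
                jobs.add(new Callable<float[]>() {
                    public float[] call() throws InterruptedException {
                        return pool.run(data);
                    }
                });
            }
            List<float[]> results = new ArrayList<float[]>();
            for (Future<float[]> f : executor.invokeAll(jobs)) {
                results.add(f.get());
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            executor.shutdown();
        }
    }
}
```

Because an instance is only ever held by one task at a time, the pooled instances do not need synchronized methods of their own.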

This does improve performance; however, on an 8-core device throughput peaks at numThreads = 2, and the Android Studio monitor shows only about 50% CPU usage.
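One possible factor, not confirmed in the answer, is that TensorFlow's native runtime can already use several threads inside a single run() call, so extra Java threads may mostly compete for the same cores; many 8-core phone SoCs also mix fast and slow cores. Either way, it is worth measuring throughput for each thread count on the target device rather than assuming linear scaling. The harness below is a generic sketch: ThroughputSweep and its methods are hypothetical names, and on a device the Runnable would perform one inference on a pooled InferenceExample.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class ThroughputSweep {
    // Runs `tasks` copies of `work` on a pool of `threads` workers and returns
    // completed tasks per second (wall-clock, including thread start-up).
    static double measureThroughput(int threads, int tasks, Runnable work) {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            List<Future<?>> futures = new ArrayList<Future<?>>();
            long start = System.nanoTime();
            for (int i = 0; i < tasks; i++) {
                futures.add(executor.submit(work));
            }
            for (Future<?> f : futures) {
                f.get();
            }
            long elapsed = Math.max(1L, System.nanoTime() - start);
            return tasks * 1e9 / elapsed;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            executor.shutdown();
        }
    }

    // Warms up once, then measures and prints throughput for 1..maxThreads.
    static double[] sweep(int maxThreads, int tasks, Runnable work) {
        measureThroughput(maxThreads, tasks, work); // warm-up pass, result discarded
        double[] results = new double[maxThreads];
        for (int t = 1; t <= maxThreads; t++) {
            results[t - 1] = measureThroughput(t, tasks, work);
            System.out.printf("threads=%d  %.1f tasks/s%n", t, results[t - 1]);
        }
        return results;
    }
}
```

The warm-up pass lets the JIT compile the workload before the timed runs; on a phone, thermal throttling can still skew later measurements, so repeating the sweep a few times gives a fairer picture.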

Answered Oct 19 '22 07:10 by Chris Seymour
