
Writing a multithreaded mapping iterator in Java

I've got a general-purpose mapping iterator, something like this:

class Mapper<F, T> implements Iterator<T> {

  private Iterator<F> input;
  private Action<F, T> action;

  public Mapper(Iterator<F> input, Action<F, T> action) {...}

  public boolean hasNext() {
    return input.hasNext();
  }

  public T next() {
    return action.process(input.next());
  }
}

Now, given that action.process() can be time-consuming, I want to gain performance by using multiple threads to process items from the input in parallel. I want to allocate a pool of N worker threads and allocate items to these threads for processing. This should happen "behind the scenes" so the client code just sees an Iterator. The code should avoid holding either the input or the output sequence in memory.

To add a twist, I want two versions of the solution, one which retains order (the final iterator delivers items in the same order as the input iterator) and one of which does not necessarily retain order (each output item is delivered as soon as it is available).

I've sort of got this working, but the code seems convoluted and unreliable, and I'm not confident it follows best practice.

Any suggestions on the simplest and most robust way of implementing this? I'm looking for something that works in JDK 6, and I want to avoid introducing dependencies on external libraries/frameworks if possible.

Michael Kay asked Jan 06 '15 09:01


3 Answers

I'd use a thread pool for the worker threads and a BlockingQueue to carry the results out of the pool.

This seems to work with my simple test cases.

interface Action<F, T> {

    public T process(F f);

}

class Mapper<F, T> implements Iterator<T> {

    protected final Iterator<F> input;
    protected final Action<F, T> action;

    public Mapper(Iterator<F> input, Action<F, T> action) {
        this.input = input;
        this.action = action;
    }

    @Override
    public boolean hasNext() {
        return input.hasNext();
    }

    @Override
    public T next() {
        return action.process(input.next());
    }
}

class ParallelMapper<F, T> extends Mapper<F, T> {

    // The pool.
    final ExecutorService pool;
    // The queue.
    final BlockingQueue<T> queue;
    // The next one to deliver.
    private T next = null;

    public ParallelMapper(Iterator<F> input, Action<F, T> action, int threads, int queueLength) {
        super(input, action);
        // Start my pool.
        pool = Executors.newFixedThreadPool(threads);
        // And the queue.
        // Explicit type argument: the diamond operator needs Java 7, the question asks for JDK 6.
        queue = new ArrayBlockingQueue<T>(queueLength);
    }

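    class Worker implements Runnable {

        // The input item this worker will process.
        final F f;

        public Worker(F f) {
            this.f = f;
        }

        @Override
        public void run() {
            try {
                // Blocks while the result queue is full.
                queue.put(action.process(f));
            } catch (InterruptedException ex) {
                // Restore the interrupt status; the result for this item is dropped.
                Thread.currentThread().interrupt();
            }
        }

    }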

    @Override
    public boolean hasNext() {
        // All done once nothing is waiting to be delivered, the input is empty,
        // the threads are finished and the queue is empty.
        // isTerminated() is checked before isEmpty(): once the pool has terminated,
        // no worker can still add to the queue.
        while (next == null && (input.hasNext() || !pool.isTerminated() || !queue.isEmpty())) {
            // First look in the queue.
            next = queue.poll();
            if (next == null) {
                // Queue empty. Note: a new item is submitted whenever the queue is
                // momentarily empty, so input can be consumed ahead of delivery.
                if (input.hasNext()) {
                    // Start a new worker.
                    pool.execute(new Worker(input.next()));
                } else if (!pool.isShutdown()) {
                    // Input exhausted - shut down the pool - unless we already have.
                    pool.shutdown();
                }
            }
        }
        return next != null;
    }

    @Override
    public T next() {
        T n = next;
        if (n != null) {
            // Delivered that one.
            next = null;
        } else {
            // Fails.
            throw new NoSuchElementException();
        }
        return n;
    }
}

public void test() {
    List<Integer> data = Arrays.asList(5, 4, 3, 2, 1, 0);
    System.out.println("Data");
    for (Integer i : Iterables.in(data)) {
        System.out.println(i);
    }
    Action<Integer, Integer> action = new Action<Integer, Integer>() {

        @Override
        public Integer process(Integer f) {
            try {
                // Wait that many seconds.
                Thread.sleep(1000L * f);
            } catch (InterruptedException ex) {
                // Just give up.
            }
            // Return it unchanged.
            return f;
        }

    };
    System.out.println("Processed");
    for (Integer i : Iterables.in(new Mapper<Integer, Integer>(data.iterator(), action))) {
        System.out.println(i);
    }
    System.out.println("Parallel Processed");
    for (Integer i : Iterables.in(new ParallelMapper<Integer, Integer>(data.iterator(), action, 2, 2))) {
        System.out.println(i);
    }

}

Note: Iterables.in(Iterator<T>) just creates an Iterable<T> that encapsulates the passed Iterator<T>.
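
The source of that helper is not shown above. A minimal sketch of what it might look like, assuming it does nothing more than hand back the wrapped iterator (the class below is a hypothetical reconstruction, not the answerer's code):

```java
import java.util.Iterator;

// Hypothetical reconstruction of the helper used in the test above.
final class Iterables {

    private Iterables() {
    }

    // Wraps an Iterator so it can be used in a for-each loop.
    // The result is single-use: every call to iterator() returns the same Iterator.
    static <T> Iterable<T> in(final Iterator<T> it) {
        return new Iterable<T>() {
            @Override
            public Iterator<T> iterator() {
                return it;
            }
        };
    }
}
```

Because the same iterator is returned each time, looping over the result a second time yields nothing.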

For your in-order one you could process Pair<Integer,F> and use a PriorityQueue for the thread output. You could then arrange to pull them in order.
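
A rough sketch of that in-order idea, under a few assumptions of mine: results (rather than inputs) are tagged with their position, a plain PriorityQueue guarded by a monitor holds finished results, and a hypothetical `window` parameter caps how many items are in flight so neither sequence is held in memory. `OrderedParallelMapper` and `Indexed` are made-up names, and only RuntimeExceptions thrown by the action are forwarded to the caller.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

interface Action<F, T> {
    T process(F f);
}

// Hypothetical class; each result is tagged with the position of its input.
class OrderedParallelMapper<F, T> implements Iterator<T> {
    private static final class Indexed<V> implements Comparable<Indexed<V>> {
        final long index;
        final V value;                  // null if processing failed
        final RuntimeException failure; // null if processing succeeded

        Indexed(long index, V value, RuntimeException failure) {
            this.index = index;
            this.value = value;
            this.failure = failure;
        }
        public int compareTo(Indexed<V> other) {
            return index < other.index ? -1 : (index > other.index ? 1 : 0);
        }
    }

    private final Iterator<F> input;
    private final Action<F, T> action;
    private final ExecutorService pool;
    private final int window; // max items submitted but not yet delivered; assumed >= 1
    // Finished results, smallest index first; guarded by synchronized (done).
    private final PriorityQueue<Indexed<T>> done = new PriorityQueue<Indexed<T>>();
    private long submitted = 0; // consuming thread only
    private long delivered = 0; // consuming thread only

    OrderedParallelMapper(Iterator<F> input, Action<F, T> action, int threads, int window) {
        this.input = input;
        this.action = action;
        this.window = window;
        this.pool = Executors.newFixedThreadPool(threads);
    }

    // Submits input items until the window is full or the input runs out.
    private void fill() {
        while (submitted - delivered < window && input.hasNext()) {
            final long index = submitted++;
            final F item = input.next();
            pool.execute(new Runnable() {
                public void run() {
                    Indexed<T> result;
                    try {
                        result = new Indexed<T>(index, action.process(item), null);
                    } catch (RuntimeException e) {
                        result = new Indexed<T>(index, null, e);
                    }
                    synchronized (done) {
                        done.add(result);
                        done.notifyAll();
                    }
                }
            });
        }
        if (!input.hasNext()) {
            pool.shutdown(); // tasks already submitted still run to completion
        }
    }

    public boolean hasNext() {
        fill();
        return delivered < submitted;
    }

    public T next() {
        if (!hasNext()) throw new NoSuchElementException();
        Indexed<T> head;
        synchronized (done) {
            // Wait until the result for the next position has arrived.
            while (done.isEmpty() || done.peek().index != delivered) {
                try {
                    done.wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("interrupted while waiting", e);
                }
            }
            head = done.poll();
        }
        delivered++;
        if (head.failure != null) throw head.failure;
        return head.value;
    }

    public void remove() { throw new UnsupportedOperationException(); }
}
```

The pool is only shut down once the input is exhausted, so a caller that abandons the iteration early would leave the worker threads alive; a real implementation would probably want an explicit close method.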

OldCurmudgeon answered Oct 22 '22 12:10


I don't think this can work with parallel threads, because hasNext() may return true, but by the time the thread calls next() there may be no more elements. It is better to use only next(), which returns null when there are no more elements.
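
To illustrate the point, here is a minimal sketch of such a next()-only contract, assuming several threads pull from one shared source and that the input never contains null elements. `NullTerminatedSource` is a hypothetical name.

```java
import java.util.Iterator;

// Hypothetical wrapper: one atomic "fetch or null" call replaces the
// hasNext()/next() pair, so no other thread can slip in between the
// check and the fetch.
final class NullTerminatedSource<T> {

    private final Iterator<T> input;

    NullTerminatedSource(Iterator<T> input) {
        this.input = input;
    }

    // Returns the next element, or null once the input is exhausted.
    // Assumes the input itself never yields null.
    synchronized T next() {
        return input.hasNext() ? input.next() : null;
    }
}
```

Each consumer then loops with something like `for (T item; (item = source.next()) != null; )` instead of asking hasNext() first.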

Evgeniy Dorofeev answered Oct 22 '22 11:10


OK, thanks everyone. This is what I've done.

First I wrap my ItemMappingFunction in a Callable:

private static class CallableAction<F extends Item, T extends Item>
        implements Callable<T> {
    private final ItemMappingFunction<F, T> action;
    private final F input;

    public CallableAction(ItemMappingFunction<F, T> action, F input) {
        this.action = action;
        this.input = input;
    }

    public T call() throws XPathException {
        return action.mapItem(input);
    }
}

I described my problem in terms of the standard Iterator class, but actually I'm using my own SequenceIterator interface, which has a single next() method that returns null at end-of-sequence.

I declare the class in terms of the "ordinary" mapping iterator like this:

public class MultithreadedMapper<F extends Item, T extends Item> extends Mapper<F, T> {

    private ExecutorService service;
    private BlockingQueue<Future<T>> resultQueue = 
        new LinkedBlockingQueue<Future<T>>();

On initialization I create the service and prime the queue:

public MultithreadedMapper(SequenceIterator base, ItemMappingFunction<F, T> action) throws XPathException {
        super(base, action);

        int maxThreads = Runtime.getRuntime().availableProcessors();
        maxThreads = maxThreads > 0 ? maxThreads : 1;
        service = Executors.newFixedThreadPool(maxThreads);

        // prime the queue
        int n = 0;
        while (n++ < maxThreads) {
            F item = (F) base.next();
            if (item == null) {
                return;
            }
            mapOneItem(item);
        }
    }

Where mapOneItem is:

private void mapOneItem(F in) throws XPathException {
    // Explicit type arguments avoid a raw-type (unchecked) warning.
    Future<T> future = service.submit(new CallableAction<F, T>(action, in));
    resultQueue.add(future);
}

When the client asks for the next item, I first submit the next input item to the executor service, then get the next output item, waiting for it to be available if necessary:

    public T next() throws XPathException {
        F nextIn = (F)base.next();
        if (nextIn != null) {
            mapOneItem(nextIn);
        }
        try {
            Future<T> future = resultQueue.poll();
            if (future == null) {
                service.shutdown();
                return null;
            } else {
                return future.get();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt status before reporting the failure.
            Thread.currentThread().interrupt();
            throw new XPathException(e);
        } catch (ExecutionException e) {
            if (e.getCause() instanceof XPathException) {
                throw (XPathException)e.getCause();
            }
            throw new XPathException(e);
        }
    }
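
For readers without the SequenceIterator and XPathException types, here is a self-contained sketch of the same Future-queue idea recast against the standard java.util.Iterator. It is an adaptation under my own assumptions, not the code above: `FutureQueueMapper` and `Action` are placeholder names, the look-ahead equals the thread count, and failures are rethrown as IllegalStateException.

```java
import java.util.Iterator;
import java.util.LinkedList;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

interface Action<F, T> {
    T process(F f);
}

// Hypothetical class: delivers results in input order by queueing Futures.
class FutureQueueMapper<F, T> implements Iterator<T> {

    private final Iterator<F> input;
    private final Action<F, T> action;
    private final ExecutorService service;
    private final int lookAhead;
    // Only touched by the consuming thread, so a plain queue is enough.
    private final Queue<Future<T>> pending = new LinkedList<Future<T>>();

    FutureQueueMapper(Iterator<F> input, Action<F, T> action, int threads) {
        this.input = input;
        this.action = action;
        this.lookAhead = threads;
        this.service = Executors.newFixedThreadPool(threads);
        fill(); // prime the queue
    }

    // Submits input items until lookAhead tasks are pending or the input runs out.
    private void fill() {
        while (pending.size() < lookAhead && input.hasNext()) {
            final F item = input.next();
            pending.add(service.submit(new Callable<T>() {
                public T call() {
                    return action.process(item);
                }
            }));
        }
        if (!input.hasNext()) {
            service.shutdown(); // tasks already submitted still run to completion
        }
    }

    public boolean hasNext() {
        return !pending.isEmpty();
    }

    public T next() {
        if (pending.isEmpty()) {
            throw new NoSuchElementException();
        }
        Future<T> head = pending.remove();
        fill(); // keep the workers busy while we wait for the head
        try {
            return head.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }

    public void remove() {
        throw new UnsupportedOperationException();
    }
}
```

Because futures are taken from the queue in submission order, output order matches input order. For the unordered variant, one option would be java.util.concurrent.ExecutorCompletionService, whose take() hands back futures in completion order.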

Michael Kay answered Oct 22 '22 12:10