 

How can I rewrite this main-thread/worker-thread synchronization?

I have a program that goes something like this:

public class Test implements Runnable
{
    public        int local_counter;
    public static int global_counter;
    // Barrier waits for as many threads as we launch + main thread
    public static CyclicBarrier thread_barrier = new CyclicBarrier (n_threads + 1);

    /* Constructors etc. */

    public void run()
    {
        for (int i=0; i<100; i++)
        {
            thread_barrier.await();
            local_counter = 0;
            for(int j=0 ; j < 20 ; j++)
                local_counter++;
            thread_barrier.await();
        }
    }

    public void main()
    {
        /* Create and launch some threads, stored on thread_array */
        for(int i=0 ; i<100 ; i++)
        {
            thread_barrier.await();
            thread_barrier.await();

            for (int t=1; t<thread_array.length; t++)
            {
                global_counter += thread_array[t].local_counter;
            }
        }
    }
}

Basically, I have a few threads, each with its own local counter, and I'm doing this (in a loop):

        |----|           |           |----|
        |main|           |           |pool|
        |----|           |           |----|
                         |

-------------------------------------------------------
barrier (get local counters before they're overwritten)
-------------------------------------------------------
                         |
                         |   1. reset local counter
                         |   2. do some computations
                         |      involving local counter
                         |
-------------------------------------------------------
             barrier (synchronize all threads)
-------------------------------------------------------
                         |
1. update global counter |
   using each thread's   |
   local counter         |

And this should all be fine and dandy, but it turns out it doesn't scale very well. On a cluster with 16 physical nodes, the speedup beyond 6-8 threads is negligible, so I need to get rid of one of the awaits. I've tried CyclicBarrier, which scales awfully; Semaphores, which do no better; and a custom library (jbarrier) that works great until there are more threads than physical cores, at which point it performs worse than the sequential version. But I just can't come up with a way of doing this without stopping all threads twice.
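The Semaphore attempt isn't shown, so the following is only a guess at what a semaphore version of the same two-phase round could look like. The class SemaphoreRounds, its run method, and the one-semaphore-per-worker layout are assumptions for illustration. The main thread still stops everyone each round: it releases every worker, then waits until all of them have reported back before summing.

```java
import java.util.concurrent.Semaphore;

// Sketch: the question's two-phase round, coordinated with semaphores instead of a barrier.
class SemaphoreRounds {
    static long run(int nThreads, int loops) throws InterruptedException {
        Semaphore[] start = new Semaphore[nThreads]; // one private "go" semaphore per worker
        Semaphore done = new Semaphore(0);           // each worker releases one permit per round
        int[] localCounters = new int[nThreads];
        Thread[] workers = new Thread[nThreads];

        for (int w = 0; w < nThreads; w++) {
            final int id = w;
            start[id] = new Semaphore(0);
            workers[id] = new Thread(() -> {
                try {
                    for (int i = 0; i < loops; i++) {
                        start[id].acquire();       // wait until main opens this round
                        int local = 0;
                        for (int j = 0; j < 20; j++)
                            local++;
                        localCounters[id] = local; // made visible to main by done.release()
                        done.release();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[id].start();
        }

        long globalCounter = 0;
        for (int i = 0; i < loops; i++) {
            for (Semaphore s : start)
                s.release();               // phase 1: let every worker run exactly once
            done.acquire(nThreads);        // phase 2: wait until all workers have finished
            for (int c : localCounters)
                globalCounter += c;
        }
        for (Thread t : workers)
            t.join();
        return globalCounter;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(8, 10_000)); // prints 1600000 (8 threads * 20 * 10000 rounds)
    }
}
```

Each worker gets its own start semaphore because, with a single shared one, a fast worker could grab a second permit in the same round while a slower worker never got to run.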

EDIT: While I appreciate any and all insights you might have concerning other possible bottlenecks in my program, I'm looking for an answer concerning this particular issue. I can provide a more specific example if needed.

Kovalainen asked Apr 11 '18 00:04


2 Answers

A few fixes: your iteration over the threads should start at t=0 (for(int t=0;...)), assuming thread_array[0] should also contribute to the global counter sum. We can guess it's an array of Test, not of Thread. local_counter should be volatile; otherwise you may not see the true value across the Test threads and the main thread.

OK, now you have a proper two-phase cycle, as far as I can tell. Anything else, like a Phaser, or a single cyclic barrier plus a new CountDownLatch on every loop, is just a variation on the same theme: getting many threads to agree to let the main thread resume, and getting the main thread to resume many threads in one shot.
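One such variation, sketched here under the assumption that the per-round work is the question's 20 increments, uses CyclicBarrier's barrier action (the CyclicBarrier(int parties, Runnable barrierAction) constructor). The last thread to arrive runs the summing code before anyone is released, so the main thread no longer takes part and each worker awaits only once per round. BarrierActionRounds and run are hypothetical names.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

// Sketch: one await per round; the last thread to arrive does the main thread's summing.
class BarrierActionRounds {
    static long run(int nThreads, int loops) throws InterruptedException {
        int[] localCounters = new int[nThreads];
        long[] globalCounter = {0}; // only touched inside the barrier action

        // The barrier action runs once per trip, in the last thread to arrive,
        // after every worker has called await() and before any of them is released.
        CyclicBarrier barrier = new CyclicBarrier(nThreads, () -> {
            for (int c : localCounters)
                globalCounter[0] += c;
        });

        Thread[] workers = new Thread[nThreads];
        for (int w = 0; w < nThreads; w++) {
            final int id = w;
            workers[id] = new Thread(() -> {
                try {
                    for (int i = 0; i < loops; i++) {
                        int local = 0;
                        for (int j = 0; j < 20; j++)
                            local++;
                        localCounters[id] = local;
                        barrier.await(); // the only synchronization point per round
                    }
                } catch (InterruptedException | BrokenBarrierException e) {
                    // stop this worker: it was interrupted or the barrier is broken
                }
            });
            workers[id].start();
        }
        for (Thread t : workers)
            t.join(); // join() makes the final globalCounter[0] visible here
        return globalCounter[0];
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(8, 10_000)); // prints 1600000
    }
}
```

Per the CyclicBarrier memory-consistency rules, writes made before await() are visible to the barrier action, so the local counters need no extra volatile here.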

A thinner implementation could use a ReentrantLock, a counter of arrived Test threads, a condition to resume all Test threads, and a condition to resume the main thread. The Test thread whose arrival brings the counter to zero (--count == 0) signals the main-resume condition. All Test threads await the test-resume condition. The main thread resets the counter to N, calls signalAll() on the test-resume condition, then awaits the main-resume condition. This way each thread (Test and main) awaits only once per loop.

Finally, if the end goal is a sum updated by many threads, you should look at LongAdder (or else AtomicLong) to add to a long concurrently without having to stop all threads (let them compete and add, without involving the main thread).
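A minimal LongAdder sketch, assuming the main thread only needs the final total and not a per-round snapshot (the question doesn't say either way). LongAdderSum and run are hypothetical names. Workers never wait on each other or on main; main only joins them and reads sum() at the end.

```java
import java.util.concurrent.atomic.LongAdder;

// Sketch: no barrier at all; every worker adds its per-iteration result directly.
class LongAdderSum {
    static long run(int nThreads, int loops) throws InterruptedException {
        LongAdder globalCounter = new LongAdder();
        Thread[] workers = new Thread[nThreads];
        for (int w = 0; w < nThreads; w++) {
            workers[w] = new Thread(() -> {
                for (int i = 0; i < loops; i++) {
                    int local = 0;
                    for (int j = 0; j < 20; j++)
                        local++;
                    globalCounter.add(local); // contended adds are spread over internal cells
                }
            });
            workers[w].start();
        }
        for (Thread t : workers)
            t.join();
        return globalCounter.sum(); // exact here because all writers have finished
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(8, 10_000)); // prints 1600000
    }
}
```

sum() is only a moving snapshot while writers are still active, which is why this sketch reads it after join().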

Otherwise, you can have the threads deliver their results to a BlockingQueue read by the main thread. There are just too many ways of doing this; I'm having a hard time understanding why you stop all threads to collect data. That's all. The question is oversimplified, and we don't have enough constraints to justify what you are doing.
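A blocking-queue version could look like this sketch, which again assumes no per-round snapshot is needed (QueueSum and run are hypothetical names). Each worker puts one result per iteration, and the main thread takes exactly nThreads * loops results, so nobody waits at a barrier.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Sketch: workers hand each iteration's result to the main thread through a queue.
class QueueSum {
    static long run(int nThreads, int loops) throws InterruptedException {
        BlockingQueue<Integer> results = new LinkedBlockingQueue<>();
        Thread[] workers = new Thread[nThreads];
        for (int w = 0; w < nThreads; w++) {
            workers[w] = new Thread(() -> {
                try {
                    for (int i = 0; i < loops; i++) {
                        int local = 0;
                        for (int j = 0; j < 20; j++)
                            local++;
                        results.put(local); // never waits for the other workers
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[w].start();
        }

        long globalCounter = 0;
        long expectedItems = (long) nThreads * loops;
        for (long k = 0; k < expectedItems; k++)
            globalCounter += results.take(); // blocks only until the next result arrives
        for (Thread t : workers)
            t.join();
        return globalCounter;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(8, 10_000)); // prints 1600000
    }
}
```

An unbounded LinkedBlockingQueue keeps the sketch simple; a bounded ArrayBlockingQueue would instead make fast workers wait in put() when the main thread falls behind.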

Don't worry about CyclicBarrier: it is implemented with a ReentrantLock, a counter, and a Condition on which the tripping thread calls signalAll() to wake all waiting threads. This is tightly coded, as far as I can tell. If you wanted a lock-free version, you would be facing too many busy-spin loops wasting CPU time, especially since you are concerned with scaling when there are more threads than cores.

Meanwhile, is it possible that you in fact have 8 hyperthreaded cores that look like 16 CPUs?

Once sanitized, your code looks like:

package tests;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.stream.Stream;

public class Test implements Runnable {
    static final int n_threads = 8;
    static final long LOOPS = 10000;
    public static int global_counter;
    public static CyclicBarrier thread_barrier = new CyclicBarrier(n_threads + 1);

    public volatile int local_counter;

    @Override
    public void run() {
        try {
            runImpl();
        } catch (InterruptedException | BrokenBarrierException e) {
            //
        }
    }

    void runImpl() throws InterruptedException, BrokenBarrierException {
        for (int i = 0; i < LOOPS; i++) {
            thread_barrier.await();
            local_counter = 0;
            for (int j=0; j<20; j++)
                local_counter++;
            thread_barrier.await();
        }
    }

    public static void main(String[] args) throws InterruptedException, BrokenBarrierException {
        Test[] ra = new Test[n_threads];
        Thread[] ta = new Thread[n_threads];
        for(int i=0; i<n_threads; i++)
            (ta[i] = new Thread(ra[i]=new Test())).start();

        long nanos = System.nanoTime();
        for (int i = 0; i < LOOPS; i++) {
            thread_barrier.await();
            thread_barrier.await();

            for (int t=0; t<ra.length; t++) {
                global_counter += ra[t].local_counter;
            }
        }

        System.out.println(global_counter+", "+1e-6*(System.nanoTime()-nanos)+" ms");

        Stream.of(ta).forEach(t -> t.interrupt());
    }
}

My version with 1 lock looks like this:

package tests;

import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Stream;

public class TwoPhaseCycle implements Runnable {
    static final boolean DEBUG = false;
    static final int N = 8;
    static final int LOOPS = 10000;

    static ReentrantLock lock = new ReentrantLock();
    static Condition testResume = lock.newCondition();
    static volatile long cycle = -1;
    static Condition mainResume = lock.newCondition();
    static volatile int testLeft = 0;

    static void p(Object msg) {
        System.out.println(Thread.currentThread().getName()+"] "+msg);
    }

    //-----
    volatile int local_counter;

    @Override
    public void run() {
        try {
            runImpl();
        } catch (InterruptedException e) {
            p("interrupted; ending.");
        }
    }

    public void runImpl() throws InterruptedException {
        lock.lock();
        try {
            if(DEBUG) p("waiting for 1st testResumed");
            while(cycle<0) {
                testResume.await();
            }
        } finally {
            lock.unlock();
        }

        long localCycle = 0;//for (int i = 0; i < LOOPS; i++) {
        while(true) {
            if(DEBUG) p("working");
            local_counter = 0;
            for (int j = 0; j<20; j++)
                local_counter++;
            localCycle++;

            lock.lock();
            try {
                if(DEBUG) p("done");
                if(--testLeft <=0)
                    mainResume.signalAll(); //could have been just .signal() since only main is waiting, but safety first.

                if(DEBUG) p("waiting for cycle "+localCycle+" testResumed");
                while(cycle < localCycle) {
                    testResume.await();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        TwoPhaseCycle[] ra = new TwoPhaseCycle[N];
        Thread[] ta = new Thread[N];
        for(int i=0; i<N; i++)
            (ta[i] = new Thread(ra[i]=new TwoPhaseCycle(), "\t\t\t\t\t\t\t\t".substring(0, i%8)+"\tT"+i)).start();

        long nanos = System.nanoTime();

        int global_counter = 0;
        for (int i=0; i<LOOPS; i++) {
            lock.lock();
            try {
                if(DEBUG) p("gathering");
                for (int t=0; t<ra.length; t++) {
                    global_counter += ra[t].local_counter;
                }
                testLeft = N;
                cycle = i;
                if(DEBUG) p("resuming cycle "+cycle+" tests");
                testResume.signalAll();

                if(DEBUG) p("waiting for main resume");
                while(testLeft>0) {
                    mainResume.await();
                }
            } finally {
                lock.unlock();
            }
        }

        System.out.println(global_counter+", "+1e-6*(System.nanoTime()-nanos)+" ms");

        p(global_counter);
        Stream.of(ta).forEach(t -> t.interrupt());
    }
}

Of course, this is by no means a rigorous microbenchmark, but the trend shows it's faster. Hope you like it. (I left in a few of my favorite debugging tricks; it's worth setting DEBUG to true...)

user2023577 answered Oct 15 '22 23:10


Well, I'm not sure I fully understand, but I think your main problem is that you try too hard to reuse a predefined set of threads. You should let Java take care of this (that's what executors and the fork-join pool are for). To solve your issue, a split/process/merge (or map/reduce) approach seems appropriate to me. Since Java 8, it's really simple to implement (thanks to the Stream, ForkJoinPool, and CompletableFuture APIs). I propose two alternatives here:

Java 8 Stream

To me, your problem looks like it can be reduced to a map/reduce problem. If you can use Java 8 streams, you can delegate the performance concerns to them. What I'd do:
1. Create a parallel stream containing your processing input (you can even generate inputs on the fly). Note that you can implement your own Spliterator to fully control the traversal and splitting of your input (cells on a grid?).
2. Use a map to process the input.
3. Use a reduce method to merge all previously computed results.

Simple example (based on your example):

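import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

public class StreamExample {
    // Hypothetical stand-in for the question's work: increment a value `times` times
    static int incrementMultipleTimes(int value, int times) {
        for (int j = 0; j < times; j++)
            value++;
        return value;
    }

    public static void main(String[] args) throws Exception {
        // Create a pool with the wanted number of threads
        final ForkJoinPool pool = new ForkJoinPool(4);
        // We give the entire procedure to the thread pool
        final int result = pool.submit(() -> {
            // Generate a hundred counters, initialized to 0
            return IntStream.generate(() -> 0)
                    .limit(100)
                    // Specify we want it processed in parallel
                    .parallel()
                    // map applies the processing method to each element
                    .map(in -> incrementMultipleTimes(in, 20))
                    // reduce merges the processing results
                    .reduce((first, second) -> first + second)
                    .orElseThrow(() -> new IllegalArgumentException("Empty dataset"));
        })
                // Wait for the overall result
                .get();

        System.out.println("RESULT: " + result);

        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
    }
}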

Some things to be aware of:
1. By default, parallel streams execute tasks on the JVM's common fork-join pool, which may be limited in its number of worker threads. But there are ways to use your own pool: see this answer.
2. If well configured, I think this is the best method, because the parallelism logic has been handled by the JDK developers themselves.

Phaser

If you cannot use Java 8 features (or I've misunderstood your problem, or you really want to handle the low-level management yourself), the last clue I can give you is the Phaser class. As stated by the documentation, it's a reusable mix of a cyclic barrier and a countdown latch. I've used it multiple times. It's complex to use, but it's also really powerful. It can be used as a cyclic barrier, so I think it fits your case.
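A minimal sketch of a Phaser used as a cyclic barrier for this case (PhaserRounds and run are hypothetical names, and the per-round work is assumed to be the question's 20 increments). Overriding onAdvance lets the last arriving party do the summing between phases, and returning true from it terminates the phaser after the final round.

```java
import java.util.concurrent.Phaser;

// Sketch: Phaser used as a cyclic barrier; onAdvance plays the role of the main thread.
class PhaserRounds {
    static long run(int nThreads, int loops) throws InterruptedException {
        int[] localCounters = new int[nThreads];
        long[] globalCounter = {0};

        Phaser phaser = new Phaser(nThreads) {
            @Override
            protected boolean onAdvance(int phase, int registeredParties) {
                // Runs once per phase, in the last arriving thread, before anyone is released
                for (int c : localCounters)
                    globalCounter[0] += c;
                return phase + 1 >= loops; // true terminates the phaser after the final round
            }
        };

        Thread[] workers = new Thread[nThreads];
        for (int w = 0; w < nThreads; w++) {
            final int id = w;
            workers[id] = new Thread(() -> {
                for (int i = 0; i < loops; i++) {
                    int local = 0;
                    for (int j = 0; j < 20; j++)
                        local++;
                    localCounters[id] = local;
                    phaser.arriveAndAwaitAdvance(); // one wait per round, like a barrier
                }
            });
            workers[id].start();
        }
        for (Thread t : workers)
            t.join();
        return globalCounter[0];
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(8, 10_000)); // prints 1600000
    }
}
```

Unlike CyclicBarrier, arriveAndAwaitAdvance() does not throw InterruptedException, and parties can be registered or deregistered between rounds if the number of workers changes.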

amanin answered Oct 15 '22 23:10