 

How to limit number of threads created and wait main thread until any one thread finds answer?

This code finds the first pair of numbers (excluding 1) whose LCM plus HCF equals the given number n.

import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

class PerfectPartition {
    static long gcd(long a, long b) {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    // method to return LCM of two numbers
    static long lcm(long a, long b) {
        return (a / gcd(a, b)) * b;
    }

    long[] getPartition(long n) {
        var ref = new Object() {
            long x;
            long y;
            long[] ret = null;
        };

        Thread mainThread = Thread.currentThread();
        ThreadGroup t = new ThreadGroup("InnerLoop");

        for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
            if (t.activeCount() < 256) {

                new Thread(t, () -> {
                    for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                        long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);
                        if (z == n) {
                            ref.ret = new long[]{ref.x, ref.y};

                            t.interrupt();
                            break;
                        }
                    }
                }, "Thread_" + ref.x).start();

                if (ref.ret != null) {
                    return ref.ret;
                }
            } else {
                ref.x--;
            }
        }//return new long[]{1, n - 2};

        return Objects.requireNonNullElseGet(ref.ret, () -> new long[]{1, n - 2});
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long[] partition = new PerfectPartition().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }
}

I want to stop the code execution as soon as the first pair is found. Instead, the main thread just keeps running and prints the fallback pair, 1 and n-2.
What would be an optimal way to limit the number of threads (to fewer than 256, since n can range from 2 up to the maximum value of a long)?

Expected Output (n=4): 2 2
Expected Output (n=8): 4 4

p2kr asked Feb 05 '21 05:02



2 Answers

What could be an optimal solution to limit the no. of threads (<256 as the range of n is 2 to max of long)?

First, you should consider the hardware the code will run on (e.g., the number of cores) and the type of algorithm you are parallelizing: is it CPU-bound, memory-bound, I/O-bound, and so on?

Your code is CPU-bound; therefore, from a performance point of view, it typically does not pay off to have more threads running than there are available cores in the system. As always, profile as much as you can.

Second, you need to distribute the work among threads in a way that justifies the parallelism. In your case:

for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
    if (t.activeCount() < 256) {

        new Thread(t, () -> {
            for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);
                if (z == n) {
                    ref.ret = new long[]{ref.x, ref.y};

                    t.interrupt();
                    break;
                }
            }
        }, "Thread_" + ref.x).start();

        if (ref.ret != null) {
            return ref.ret;
        }
    } else {
        ref.x--;
    }
}//return new long[]{1, n - 2};

which you sort of did, but in a convoluted way IMO. It is much simpler to parallelize the loop explicitly, i.e., split its iterations among the threads, and remove all the ThreadGroup-related logic.

Third, look out for race conditions such as:

var ref = new Object() {
    long x;
    long y;
    long[] ret = null;
};

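This object is shared among the threads and updated by them, which leads to race conditions: the main loop keeps incrementing ref.x while the workers read it, so a worker may test a different x than the one it was started for, and all workers advance the same ref.y in their inner loops. As we are about to see, you do not actually need such a shared object anyway.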
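A Java lambda can only capture effectively final local variables, so a loop counter such as x cannot be used inside it directly. The usual alternative to a shared holder object is to copy the counter into a fresh local on every iteration and give each worker its own output slot. A minimal sketch, assuming one worker per x purely to isolate the capture fix (it is not a recommendation for the thread count; `LocalCopyDemo` and `firstYPerX` are hypothetical names):

```java
import java.util.Arrays;

// Hypothetical demo: one worker per x, each with its own copy of x and its own result slot.
class LocalCopyDemo {

    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    // For each x in [2, (n + 2) / 2), finds the smallest y in the same range
    // with lcm(x, y) + gcd(x, y) == n; slot x - 2 holds that y, or -1 if none.
    static long[] firstYPerX(long n) {
        long limit = (n + 2) / 2;
        int count = (int) Math.max(0, limit - 2);   // one thread per x: only sensible for small n
        long[] firstY = new long[count];            // slot i is written only by the worker for x = i + 2
        Thread[] workers = new Thread[count];
        for (int i = 0; i < count; i++) {
            final long x = i + 2;                   // fresh, effectively final copy for this iteration
            final int slot = i;
            workers[i] = new Thread(() -> {
                firstY[slot] = -1;
                for (long y = 2; y < limit; y++) {  // y is local to this worker, not a shared field
                    long g = gcd(x, y);
                    if ((x / g) * y + g == n) {     // lcm(x, y) + gcd(x, y) == n
                        firstY[slot] = y;
                        return;
                    }
                }
            });
            workers[i].start();
        }
        try {
            for (Thread w : workers) {
                w.join();                           // join() also makes each worker's write visible here
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return firstY;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(firstYPerX(8)));  // prints [-1, -1, 4]
    }
}
```

For n = 8 only the worker for x = 4 finds a match (y = 4), and no worker can observe another worker's x or y.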

So let us do this step by step:

First, find out how many threads to run the code with, i.e., the same number of threads as there are cores:

int cores = Runtime.getRuntime().availableProcessors();

Define the parallel work (this is one possible loop distribution):

public void run() {
    // long loop variables, because n can be as large as Long.MAX_VALUE
    for (long x = 2; x < (n + 2) / 2; x++) {
        for (long y = 2 + threadID; y < (n + 2) / 2; y += total_threads) {
            long z = lcm(x, y) + gcd(x, y);
            if (z == n) {
                // do something
            }
        }
    }
}

In the code above, the iterations of the inner loop are split among the threads in round-robin fashion: thread 0 takes y = 2, 2 + total_threads, 2 + 2 * total_threads, ..., thread 1 takes y = 3, 3 + total_threads, ..., and so on.
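The round-robin split can be checked without running any threads. A small sketch (`RoundRobin` and `sliceFor` are illustrative names, not from the answer) that lists the y values a given thread visits, using the same `2 + threadID` start and `total_threads` stride as the loop above:

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative helper: which y values does thread `threadID` handle?
class RoundRobin {

    static List<Long> sliceFor(int threadID, int totalThreads, long limit) {
        List<Long> ys = new ArrayList<>();
        for (long y = 2 + threadID; y < limit; y += totalThreads) {  // same start and stride as run()
            ys.add(y);
        }
        return ys;
    }

    public static void main(String[] args) {
        for (int t = 0; t < 3; t++) {
            System.out.println("thread " + t + ": " + sliceFor(t, 3, 10));
        }
    }
}
```

With 3 threads and a limit of 10, thread 0 gets [2, 5, 8], thread 1 gets [3, 6, 9] and thread 2 gets [4, 7]: every y in [2, 10) is visited by exactly one thread.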

I want to stop the code execution as soon as the first pair is found.

There are several ways of achieving this. I will show the simplest one IMO, albeit not the most sophisticated: use a shared variable to signal to the threads that the result has already been found, for instance:

final AtomicBoolean found;

Each thread shares the same AtomicBoolean variable, so a change made by one of them is visible to the others:

@Override
public void run() {
    for (long x = 2; !found.get() && x < (n + 2) / 2; x++) {
        for (long y = 2 + threadID; !found.get() && y < (n + 2) / 2; y += total_threads) {
            long z = lcm(x, y) + gcd(x, y);
            if (z == n) {
                synchronized (found) {
                    if (!found.get()) {  // only the first thread to get here records its pair
                        rest[0] = x;
                        rest[1] = y;
                        found.set(true);
                    }
                    return;
                }
            }
        }
    }
}
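The `synchronized` block works, but the same "only the first finder records its pair" rule can also be expressed with a single atomic compare-and-set. A minimal sketch of that swapped-in technique, assuming the pair itself is stored in an `AtomicReference<long[]>` instead of a separate `rest` array plus an `AtomicBoolean` (`FirstPair` is a hypothetical name):

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical replacement for the AtomicBoolean + synchronized + rest[] combination.
class FirstPair {

    private final AtomicReference<long[]> pair = new AtomicReference<>(null);

    // Atomically records {x, y} only if nothing was recorded yet;
    // returns true for exactly one caller.
    boolean offer(long x, long y) {
        return pair.compareAndSet(null, new long[]{x, y});
    }

    // Workers can poll this in their loop conditions instead of found.get().
    boolean isFound() {
        return pair.get() != null;
    }

    long[] get() {
        return pair.get();
    }
}
```

In `run()`, the match branch would then shrink to `first.offer(x, y); return;` and the loop conditions would test `!first.isFound()`.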

Since you asked for a code example, here is a simple, non-bulletproof (and not thoroughly tested) running example:

import java.util.Scanner;
import java.util.concurrent.atomic.AtomicBoolean;

class ThreadWork implements Runnable {

    final long[] rest;          // shared result holder: rest[0] = x, rest[1] = y
    final AtomicBoolean found;  // shared flag, set once any thread finds a pair
    final int threadID;
    final int total_threads;
    final long n;

    ThreadWork(long[] rest, AtomicBoolean found, int threadID, int total_threads, long n) {
        this.rest = rest;
        this.found = found;
        this.threadID = threadID;
        this.total_threads = total_threads;
        this.n = n;
    }

    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    // LCM from an already computed GCD, avoiding a second gcd() call
    static long lcm(long a, long b, long gcd) {
        return (a / gcd) * b;
    }

    @Override
    public void run() {
        // long loop variables, because n can be as large as Long.MAX_VALUE
        for (long x = 2; !found.get() && x < (n + 2) / 2; x++) {
            // round-robin: thread i checks y = 2 + i, 2 + i + total_threads, ...
            for (long y = 2 + threadID; !found.get() && y < (n + 2) / 2; y += total_threads) {
                long result = gcd(x, y);
                long z = lcm(x, y, result) + result;
                if (z == n) {
                    synchronized (found) {
                        if (!found.get()) {  // only the first thread to get here records its pair
                            rest[0] = x;
                            rest[1] = y;
                            found.set(true);
                        }
                        return;
                    }
                }
            }
        }
    }
}

class PerfectPartition {

    public static void main(String[] args) throws InterruptedException {
        Scanner sc = new Scanner(System.in);
        final long n = sc.nextLong();
        final int total_threads = Runtime.getRuntime().availableProcessors();

        long[] rest = new long[2];  // stays {0, 0} if no pair is found
        AtomicBoolean found = new AtomicBoolean();

        long startTime = System.nanoTime();
        Thread[] threads = new Thread[total_threads];
        for (int i = 0; i < total_threads; i++) {
            ThreadWork task = new ThreadWork(rest, found, i, total_threads, n);
            threads[i] = new Thread(task);
            threads[i].start();
        }

        // the main thread waits here until every worker has finished
        for (int i = 0; i < total_threads; i++) {
            threads[i].join();
        }

        long estimatedTime = System.nanoTime() - startTime;
        System.out.println(rest[0] + " " + rest[1]);

        double elapsedTimeInSecond = estimatedTime / 1_000_000_000.0;
        System.out.println(elapsedTimeInSecond + " seconds");
    }
}

OUTPUT:

4 -> 2 2
8 -> 4 4

Use this code as inspiration to come up with your own solution that best fits your requirements. After you fully understand these basics, try to improve the approach with more sophisticated Java features such as Executors, Futures, and CountDownLatch.
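As one example of those higher-level tools, `ExecutorService.invokeAny` already implements "wait until any one task succeeds, then cancel the rest". A sketch under these assumptions: a fixed pool sized to the core count, the same round-robin split of y, a task that throws when its slice contains no pair (so `invokeAny` keeps waiting for the others), and `null` returned when no task succeeds. `InvokeAnySearch` and `findPair` are hypothetical names. When several pairs exist, whichever task succeeds first wins, so the result is not necessarily the one with the smallest x.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class InvokeAnySearch {

    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    // Searches one round-robin slice of y values; throws if the slice holds no pair.
    static long[] search(long n, int threadID, int totalThreads) {
        long limit = (n + 2) / 2;
        for (long x = 2; x < limit; x++) {
            for (long y = 2 + threadID; y < limit; y += totalThreads) {
                if (Thread.currentThread().isInterrupted()) {
                    throw new CancellationException("another task already succeeded");
                }
                long g = gcd(x, y);
                if ((x / g) * y + g == n) {  // lcm(x, y) + gcd(x, y) == n
                    return new long[]{x, y};
                }
            }
        }
        throw new NoSuchElementException("no pair in slice " + threadID);
    }

    // The calling thread blocks until one slice finds a pair; returns null if none does.
    static long[] findPair(long n) {
        int threads = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<long[]>> tasks = new ArrayList<>();
            for (int i = 0; i < threads; i++) {
                final int id = i;
                tasks.add(() -> search(n, id, threads));
            }
            return pool.invokeAny(tasks);  // first successful result; the other tasks get cancelled
        } catch (ExecutionException e) {
            return null;                   // every task threw: no pair with x, y >= 2
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        } finally {
            pool.shutdownNow();
        }
    }

    public static void main(String[] args) {
        long[] p = findPair(8);
        System.out.println(p == null ? "none" : p[0] + " " + p[1]);
    }
}
```

Whether to fall back to a default pair when `findPair` returns `null` is left to the caller.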


NEW UPDATE: Sequential Optimization

Looking at the gcd method:

  static long gcd(long a, long b) {
        return (a == 0)? b : gcd(b % a, a);
  }

and the lcm method:

static long lcm(long a, long b) {
    return (a / gcd(a, b)) * b;
}

and how they are being used:

long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);

you can optimize your sequential code by not calling gcd(a, b) a second time inside the lcm method. So change the lcm method to:

static long lcm(long a, long b, long gcd) {
    return (a / gcd) * b;
}

and

long z = lcm(ref.x, ref.y) + gcd(ref.x, ref.y);

to

long result = gcd(ref.x, ref.y);
long z = lcm(ref.x, ref.y, result) + result;

The code that I have provided in this answer already reflects those changes.
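For reference, here is a sequential sketch with that single-gcd optimization applied. It also pins down what "first" means: x is scanned before y, both in ascending order, so the pair with the smallest x (then the smallest y) is returned, whereas the threaded versions can report a different pair when several exist. Returning `null` when no pair with x, y ≥ 2 exists is an assumption of this sketch, and `SequentialSearch`/`firstPair` are hypothetical names.

```java
class SequentialSearch {

    static long gcd(long a, long b) {
        return (a == 0) ? b : gcd(b % a, a);
    }

    // LCM from an already computed GCD: gcd() runs only once per (x, y) pair.
    static long lcm(long a, long b, long gcd) {
        return (a / gcd) * b;
    }

    // Smallest x, then smallest y, in [2, (n + 2) / 2) with lcm + gcd == n; null if none.
    static long[] firstPair(long n) {
        long limit = (n + 2) / 2;
        for (long x = 2; x < limit; x++) {
            for (long y = 2; y < limit; y++) {
                long g = gcd(x, y);
                if (lcm(x, y, g) + g == n) {
                    return new long[]{x, y};
                }
            }
        }
        return null;
    }

    public static void main(String[] args) {
        long[] p = firstPair(7);
        System.out.println(p[0] + " " + p[1]);  // prints "2 3"
    }
}
```

For n = 7 both (2, 3) and (3, 2) satisfy lcm + gcd = 7 (6 + 1); this version always returns 2 3.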

dreamcrash answered Sep 23 '22 12:09


First of all, make sure the thread is actually started by calling start() on it:

new Thread(t, () -> {
    ...
    ...
}, "Thread_" + ref.x).start();

As for your question: to limit the number of threads, you can use a thread pool, for example Executors.newFixedThreadPool(int nThreads).
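A small sketch showing that a pool created with `Executors.newFixedThreadPool(nThreads)` never runs more than `nThreads` tasks at once, no matter how many are submitted; the extra tasks wait in the pool's queue. It records the peak number of concurrently running tasks (`PoolBoundDemo` and `peakConcurrency` are hypothetical names, and the 5 ms sleep merely simulates work):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PoolBoundDemo {

    // Submits `tasks` short jobs to a pool of `nThreads` threads and returns the peak concurrency observed.
    static int peakConcurrency(int nThreads, int tasks) {
        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        AtomicInteger running = new AtomicInteger();
        AtomicInteger peak = new AtomicInteger();
        for (int i = 0; i < tasks; i++) {
            pool.execute(() -> {
                int now = running.incrementAndGet();
                peak.accumulateAndGet(now, Math::max);  // remember the highest count seen so far
                try {
                    Thread.sleep(5);                     // simulate some work
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                running.decrementAndGet();
            });
        }
        pool.shutdown();                                 // accept no new tasks, finish the queued ones
        try {
            pool.awaitTermination(1, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return peak.get();
    }

    public static void main(String[] args) {
        System.out.println("peak = " + peakConcurrency(4, 100));
    }
}
```

With 4 threads and 100 submitted tasks, the reported peak never exceeds 4.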

To stop the execution as soon as a match is found, have the main thread wait on a CountDownLatch with a count of 1, have the worker thread count the latch down when it finds a match, and, once the wait on the latch completes, shut down the thread pool from the main thread.

As requested, here is sample code that uses thread pools and a CountDownLatch:

import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;

public class LcmHcmSum {

    static long gcd(long a, long b) {
        if (a == 0)
            return b;
        return gcd(b % a, a);
    }

    // method to return LCM of two numbers
    static long lcm(long a, long b) {
        return (a / gcd(a, b)) * b;
    }

    long[] getPartition(long n) {
        singleThreadJobSubmitter.execute(() -> {
            try {
                for (long x = 2; x < (n + 2) / 2; x++) {
                    submitjob(n, x);
                    if (numberPair != null) break;  // match found, stop submitting jobs
                }
                jobsExecutor.shutdown();  // process the already submitted jobs
                jobsExecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);  // wait for all of them to complete

                if (numberPair == null) {  // no match found, all jobs processed, nothing more to do: count down the latch
                    latch.countDown();
                }
            } catch (RejectedExecutionException e) {
                // the main thread already shut the pool down after a match; nothing left to submit
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();  // cancelled by shutdownNow() in the main thread
            }
        });

        try {
            latch.await();  // the main thread blocks here until a match is found or the search is exhausted
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            singleThreadJobSubmitter.shutdownNow();
            jobsExecutor.shutdownNow();  // interrupts the jobs that are still searching
        }
        return Objects.requireNonNullElseGet(numberPair, () -> new long[]{1, n - 2});
    }

    private Future<?> submitjob(long n, long x) {
        return jobsExecutor.submit(() -> {
            // stop early once any job has found a pair or the pool has been shut down
            for (long y = 2; y < (n + 2) / 2 && numberPair == null
                    && !Thread.currentThread().isInterrupted(); y++) {
                long z = lcm(x, y) + gcd(x, y);
                if (z == n) {
                    synchronized (LcmHcmSum.class) {
                        if (numberPair == null) numberPair = new long[]{x, y};  // keep only the first recorded match
                    }
                    latch.countDown();
                    break;
                }
            }
        });
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long[] partition = new LcmHcmSum().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }

    private static final CountDownLatch latch = new CountDownLatch(1);
    private static final ExecutorService jobsExecutor = Executors.newFixedThreadPool(4);  // at most 4 worker threads
    private static volatile long[] numberPair = null;
    private static final ExecutorService singleThreadJobSubmitter = Executors.newSingleThreadExecutor();
}
user15117826 answered Sep 19 '22 12:09