Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Matrix Multiplication with threads Java

I'm trying to create a Java program with threads for matrix multiplication. This is the source code:

import java.util.Random;

/**
 * Demo driver: fills two 3x3 matrices with random digits, lets three worker
 * threads compute their product through a shared {@code Multiply} object, and
 * prints the resulting matrix.
 */
public class MatrixTest {
    // Operands and product; static so the Multiply subclass can reach them directly
    static int[][] mat = new int[3][3];
    static int[][] mat2 = new int[3][3];
    static int[][] result = new int[3][3];

    public static void main(String[] args) {
        // One generator feeds both operands, first matrix first
        Random rand = new Random();
        fillWithRandomDigits(mat, rand);
        fillWithRandomDigits(mat2, rand);

        try {
            // Single Multiply instance shared by every worker
            Multiply multiply = new Multiply(3, 3);

            // Build the three workers up front
            Thread[] workers = new Thread[3];
            for (int n = 0; n < workers.length; n++) {
                workers[n] = new Thread(new MatrixMultiplier(multiply));
            }

            // Launch all of them, then wait for all of them
            for (Thread worker : workers) {
                worker.start();
            }
            for (Thread worker : workers) {
                worker.join();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        // Dump the product, one row per line
        System.out.println("\n\nResult:");
        for (int[] row : result) {
            for (int cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
    }

    // Overwrites every cell of target with a random value in the range 0..9, row by row
    private static void fillWithRandomDigits(int[][] target, Random source) {
        for (int[] row : target) {
            for (int col = 0; col < row.length; col++) {
                row[col] = source.nextInt(10);
            }
        }
    }
}

//Multiply Class
/**
 * Computes the product {@code mat * mat2} into {@code result} one row per call.
 *
 * <p>Each call to {@link #multiplyMatrix()} fills the next unprocessed row of
 * {@code result}; the method is synchronized so several threads can share one
 * instance and each pick up a different row.
 */
class Multiply extends MatrixTest {
    // Number of result columns to fill per row; also used as the number of rows to process
    private final int i;
    // Length of the inner (shared) dimension of the two operands
    private final int j;
    // Index of the next row of the result that has not been computed yet
    private int chance;

    public Multiply(int i, int j) {
        this.i = i;
        this.j = j;
        chance = 0;
    }

    /**
     * Computes the next pending row of the result and advances to the following row.
     *
     * <p>Once {@code i} rows have been produced, further calls are no-ops. The
     * bound is checked BEFORE touching the arrays: checking it afterwards (as the
     * code previously did) lets an extra call index one row past the end and
     * throw {@code ArrayIndexOutOfBoundsException}, e.g. when more threads than
     * rows share this object.
     */
    public synchronized void multiplyMatrix() {
        if (chance >= i) {
            return;
        }

        for (int a = 0; a < i; a++) {
            int sum = 0;
            for (int b = 0; b < j; b++) {
                sum += mat[chance][b] * mat2[b][a];
            }
            result[chance][a] = sum;
        }

        chance++;
    }
}//End multiply class

//Thread Class
/**
 * Runnable task that, when run, asks the Multiply object it was given to
 * perform one {@code multiplyMatrix()} call.
 */
class MatrixMultiplier implements Runnable {
    // Multiply instance this task delegates to
    private final Multiply delegate;

    public MatrixMultiplier(Multiply mul) {
        delegate = mul;
    }

    // Performs exactly one multiplyMatrix() call on the delegate
    @Override
    public void run() {
        delegate.multiplyMatrix();
    }
}

I tried it in Eclipse and it works, but now I want to create another version of the program in which I use one thread for each cell of the result matrix. For example, I have two 3x3 matrices, so the result matrix will also be 3x3, and I want to use 9 threads, each one calculating one of the 9 cells of the result matrix.

Can anyone help me?

like image 944
WhatElse88 Avatar asked Sep 03 '15 09:09

WhatElse88


4 Answers

You can create n Threads as follows (Note: numberOfThreads is the number of threads that you want to create. This will be the number of cells):

// Start numberOfThreads workers, keeping a handle on each so it can be joined afterwards
List<Thread> threads = new ArrayList<>(numberOfThreads);

int launched = 0;
while (launched < numberOfThreads) {
   Thread worker = new Thread(new MatrixMultiplier(multiply));
   worker.start();
   threads.add(worker);
   launched++;
}

// Block until every started worker has finished
for (Thread worker : threads) {
   worker.join();
}
like image 87
Nicholas Robinson Avatar answered Oct 23 '22 04:10

Nicholas Robinson


Please use the new Executor framework to create Threads, instead of manually doing the plumbing.

// Run numberOfThreads multiplication tasks on a fixed-size pool of worker threads
ExecutorService executor = Executors.newFixedThreadPool(numberOfThreadsInPool);
for (int i = 0; i < numberOfThreads; i++) {
  // Hand the Runnable straight to the pool; wrapping it in a Thread only creates
  // an extra Thread object that is never started
  executor.execute(new MatrixMultiplier(multiply));
}
// Stop accepting new tasks; tasks already submitted still run to completion
executor.shutdown();
try {
  // Sleep in awaitTermination instead of spinning on isTerminated(), which burns a CPU core
  while (!executor.awaitTermination(1, java.util.concurrent.TimeUnit.SECONDS)) {
    // tasks still running; keep waiting
  }
} catch (InterruptedException e) {
  // Give up waiting, cancel outstanding work and preserve the interrupt for the caller
  executor.shutdownNow();
  Thread.currentThread().interrupt();
}
like image 38
Rob Audenaerde Avatar answered Oct 23 '22 05:10

Rob Audenaerde


Consider Matrix.java and Main.java as follows.

/**
 * Worker thread that computes one cell of a matrix product, together with the
 * static {@link #multiply(int[][], int[][])} entry point that runs one worker
 * per cell of the result.
 */
public class Matrix extends Thread {
    /*
     * Result matrix passed to the most recently constructed worker. Retained only
     * so that returnC() keeps its previous meaning; workers never read it.
     */
    private static int[][] latestResult;

    /*
     * Per-worker state. These used to be static, which meant every constructor
     * call overwrote the matrices used by all previously created workers.
     */
    private final int[][] a;
    private final int[][] b;
    private final int[][] c;
    private final int i;  // row of the cell this worker computes
    private final int j;  // column of the cell this worker computes
    private final int z1; // length of the shared dimension (a[0].length)

    /**
     * @param A  left operand
     * @param B  right operand
     * @param C  matrix that receives the computed cell
     * @param i  row index of the cell to compute
     * @param j  column index of the cell to compute
     * @param z1 number of terms in the dot product (columns of A / rows of B)
     */
    public Matrix(int[][] A, final int[][] B, final int[][] C, int i, int j, int z1) {
        this.a = A;
        this.b = B;
        this.c = C;
        this.i = i;
        this.j = j;
        this.z1 = z1;
        latestResult = C;
    }

    /**
     * Computes the dot product of row {@code i} of A and column {@code j} of B and
     * stores it in {@code C[i][j]}.
     *
     * <p>No locking is needed: each worker writes a different cell, and the
     * start()/join() calls in multiply() make the writes visible to the caller.
     */
    @Override
    public void run() {
        int sum = 0;
        for (int k = 0; k < z1; k++) {
            sum += a[i][k] * b[k][j];
        }
        c[i][j] = sum;
    }

    /** Returns the result matrix that was passed to the most recently constructed worker. */
    public static int[][] returnC() {
        return latestResult;
    }

    /**
     * Multiplies {@code a} by {@code b} using one thread per cell of the result.
     *
     * @return the product, or {@code null} (after printing a message) when either
     *         operand is null or empty or the inner dimensions differ; if the
     *         calling thread is interrupted while waiting, the returned matrix
     *         may be incomplete and the interrupt flag is left set
     */
    public static int[][] multiply(final int[][] a, final int[][] b) {
        // Empty operands used to crash on a[0] / b[0]; report them like any other mismatch
        if (a == null || b == null || a.length == 0 || b.length == 0) {
            System.out.println("Cannnot multiply");
            return null;
        }

        final int x = a.length;
        final int y = b[0].length;

        final int z1 = a[0].length;
        final int z2 = b.length;

        if (z1 != z2) {
            System.out.println("Cannnot multiply");
            return null;
        }

        final int[][] c = new int[x][y];
        final Matrix[][] workers = new Matrix[x][y];

        // Start every worker first. Joining each one right after starting it (as
        // the code previously did) runs the cells strictly one after another.
        for (int row = 0; row < x; row++) {
            for (int col = 0; col < y; col++) {
                workers[row][col] = new Matrix(a, b, c, row, col, z1);
                workers[row][col].start();
            }
        }

        // Only now wait for all of them to finish
        try {
            for (Matrix[] workerRow : workers) {
                for (Matrix worker : workerRow) {
                    worker.join();
                }
            }
        } catch (InterruptedException e) {
            // Stop waiting and keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        return c;
    }
}

You can use Main.java to give 2 matrices that need to be multiplied.

/**
 * Sample driver: multiplies a fixed 3x3 matrix of ones by a 3x1 column of ones
 * via Matrix.multiply and prints the product.
 */
class Main {
    // Left operand: 3x3, all ones
    public static int[][] a = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}};

    // Right operand: 3x1, all ones
    public static int[][] b = {{1}, {1}, {1}};

    // Prints the matrix one row per line, each value followed by a single space
    public static void print_matrix(int[][] a) {
        for (int[] row : a) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    public static void main(String[] args) {
        // Print the product so the multiplication can be checked by eye
        print_matrix(Matrix.multiply(a, b));
    }
}
like image 28
Viraj Dhanushka Avatar answered Oct 23 '22 05:10

Viraj Dhanushka


I think this code solves my problem. I don't use synchronized in the methods, but I don't think it is necessary in this case, because each thread writes only to its own cell of the result matrix.

import java.util.Scanner;

class MatrixProduct extends Thread {
    private int[][] A;
    private int[][] B;
    private int[][] C;
    private int rig, col;
    private int dim;

    public MatrixProduct(int[][] A, int[][] B, int[][] C, int rig, int col, int dim_com) {
        this.A = A;
        this.B = B;
        this.C = C;
        this.rig = rig;
        this.col = col;
        this.dim = dim_com;
    }

    public void run() {
        for (int i = 0; i < dim; i++) {
            C[rig][col] += A[rig][i] * B[i][col];
        }
        System.out.println("Thread " + rig + "," + col + " complete.");
    }
}

/**
 * Console program: reads the dimensions and contents of two matrices from
 * standard input, multiplies them with one MatrixProduct thread per result
 * cell, and prints the product.
 */
public class MatrixMultiplication {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);

        System.out.print("Row of Matrix A: ");
        int rA = in.nextInt();
        System.out.print("Column of Matrix A: ");
        int cA = in.nextInt();
        System.out.print("Row of Matrix B: ");
        int rB = in.nextInt();
        System.out.print("Column of Matrix B: ");
        int cB = in.nextInt();
        System.out.println();

        // The product is only defined when A's column count equals B's row count
        if (cA != rB) {
            System.out.println("We can't do the matrix product!");
            System.exit(-1);
        }
        System.out.println("The matrix result from product will be " + rA + " x " + cB);
        System.out.println();

        int[][] A = readMatrix(in, "Insert A:", rA, cA);
        int[][] B = readMatrix(in, "Insert B:", rB, cB);
        int[][] C = new int[rA][cB];
        MatrixProduct[][] thrd = new MatrixProduct[rA][cB];

        // Start one worker per result cell before waiting on any of them
        for (int i = 0; i < rA; i++) {
            for (int j = 0; j < cB; j++) {
                thrd[i][j] = new MatrixProduct(A, B, C, i, j, cA);
                thrd[i][j].start();
            }
        }

        try {
            for (MatrixProduct[] row : thrd) {
                for (MatrixProduct worker : row) {
                    worker.join();
                }
            }
        } catch (InterruptedException e) {
            // Previously swallowed, which could print a half-computed matrix as if
            // it were complete. Restore the flag and stop instead.
            Thread.currentThread().interrupt();
            System.err.println("Interrupted while waiting for the worker threads; result is incomplete.");
            return;
        }

        System.out.println();
        System.out.println("Result");
        System.out.println();
        printMatrix(C);
    }

    // Prints the title, then prompts for and reads a rows x cols matrix cell by cell
    private static int[][] readMatrix(Scanner in, String title, int rows, int cols) {
        int[][] matrix = new int[rows][cols];
        System.out.println(title);
        System.out.println();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                System.out.print(i + "," + j + " = ");
                matrix[i][j] = in.nextInt();
            }
        }
        System.out.println();
        return matrix;
    }

    // Prints the matrix one row per line, each value followed by a single space
    private static void printMatrix(int[][] matrix) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }
}
like image 23
WhatElse88 Avatar answered Oct 23 '22 05:10

WhatElse88