Multi-threaded matrix multiplication

I've written a multi-threaded matrix multiplication. I believe my approach is right, but I'm not 100% sure. With regard to the threads, I don't understand why I can't just run (new MatrixThread(...)).start() instead of using an ExecutorService.

Additionally, when I benchmark the multithreaded approach against the classical approach, the classical one is much faster...

What am I doing wrong?

Matrix Class:

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class Matrix
{
   private int dimension;
   private int[][] template;

   public Matrix(int dimension)
   {
      this.template = new int[dimension][dimension];
      this.dimension = template.length;
   }

   public Matrix(int[][] array) 
   {
      this.dimension = array.length;
      this.template = array;      
   }

   public int getMatrixDimension() { return this.dimension; }

   public int[][] getArray() { return this.template; }

   public void fillMatrix()
   {
      Random randomNumber = new Random();
      for(int i = 0; i < dimension; i++)
      {
         for(int j = 0; j < dimension; j++)
         {
            template[i][j] = randomNumber.nextInt(10) + 1;
         }
      }
   }

   @Override
   public String toString()
   {
      String retString = "";
      for(int i = 0; i < this.getMatrixDimension(); i++)
      {
         for(int j = 0; j < this.getMatrixDimension(); j++)
         {
            retString += " " + this.getArray()[i][j];
         }
         retString += "\n";
      }
      return retString;
   }

   public static Matrix classicalMultiplication(Matrix a, Matrix b)
   {      
      int[][] result = new int[a.dimension][b.dimension];
      for(int i = 0; i < a.dimension; i++)
      {
         for(int j = 0; j < b.dimension; j++)
         {
            for(int k = 0; k < b.dimension; k++)
            {
               result[i][j] += a.template[i][k] * b.template[k][j];
            }
         }
      }
      return new Matrix(result);
   }

   public Matrix multiply(Matrix multiplier) throws InterruptedException
   {
      Matrix result = new Matrix(dimension);
      ExecutorService es = Executors.newFixedThreadPool(dimension*dimension);
      for(int currRow = 0; currRow < multiplier.dimension; currRow++)
      {
         for(int currCol = 0; currCol < multiplier.dimension; currCol++)
         {            
            //(new MatrixThread(this, multiplier, currRow, currCol, result)).start();            
            es.execute(new MatrixThread(this, multiplier, currRow, currCol, result));
         }
      }
      es.shutdown();
      es.awaitTermination(2, TimeUnit.DAYS);
      return result;
   }

   private class MatrixThread extends Thread
   {
      private Matrix a, b, result;
      private int row, col;      

      private MatrixThread(Matrix a, Matrix b, int row, int col, Matrix result)
      {         
         this.a = a;
         this.b = b;
         this.row = row;
         this.col = col;
         this.result = result;
      }

      @Override
      public void run()
      {
         int cellResult = 0;
         for (int i = 0; i < a.getMatrixDimension(); i++)
            cellResult += a.template[row][i] * b.template[i][col];

         result.template[row][col] = cellResult;
      }
   }
} 

Main class:

import java.util.Scanner;

public class MatrixDriver
{
   private static final Scanner kb = new Scanner(System.in);

   public static void main(String[] args) throws InterruptedException
   {      
      Matrix first, second;
      long timeLastChanged,timeNow;
      double elapsedTime;

      System.out.print("Enter value of n (must be a power of 2):");
      int n = kb.nextInt();

      first = new Matrix(n);
      first.fillMatrix();      
      second = new Matrix(n);
      second.fillMatrix();

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using threads:\n" + first.multiply(second));
      first.multiply(second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Threaded took "+elapsedTime+" seconds");

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using classical:\n" + Matrix.classicalMultiplication(first,second));
      Matrix.classicalMultiplication(first,second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Classical took "+elapsedTime+" seconds");
   }
} 

P.S. Please let me know if any further clarification is needed.

Alex Wood asked Oct 15 '09 18:10

1 Answer

There is a lot of overhead involved in creating threads, even when using an ExecutorService. Your pool is sized dimension*dimension, so every cell can end up with a thread of its own. I suspect your multithreaded approach is so slow because you're spending 99% of the time creating threads and only 1%, or less, doing the actual math.

Typically, to solve this problem you'd batch a whole bunch of operations together and run each batch on a single thread. I'm not 100% sure how best to do that in this case, but I suggest breaking your matrix into smaller chunks (say, 10 smaller matrices) and running each chunk on its own thread, instead of running each cell in its own thread.
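One possible way to do the batching, as a rough sketch rather than a tuned implementation: give each task a contiguous band of rows and size the pool to the number of cores instead of dimension*dimension. ChunkedMultiply and its multiply(a, b, threads) method are hypothetical names of mine, not part of your Matrix class, and the sketch assumes square int matrices of equal size and Java 8 or later for the lambda.

```java
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical helper for illustration; not part of the Matrix class in the question.
class ChunkedMultiply
{
   // Multiplies two square matrices of the same size.
   // Each task computes a contiguous band of result rows, so there are
   // at most 'threads' tasks in total instead of one task per cell.
   static int[][] multiply(int[][] a, int[][] b, int threads)
   {
      final int n = a.length;
      final int[][] result = new int[n][n];
      if (n == 0) return result;

      // Never use more workers than rows, and always at least one.
      final int workers = Math.max(1, Math.min(threads, n));
      final int rowsPerTask = (n + workers - 1) / workers; // ceiling division
      ExecutorService es = Executors.newFixedThreadPool(workers);

      for (int start = 0; start < n; start += rowsPerTask)
      {
         final int from = start;
         final int to = Math.min(n, start + rowsPerTask);
         es.execute(() ->
         {
            // Tasks write to disjoint rows of result, so no locking is needed.
            for (int i = from; i < to; i++)
            {
               for (int k = 0; k < n; k++)
               {
                  int aik = a[i][k];
                  for (int j = 0; j < n; j++)
                     result[i][j] += aik * b[k][j];
               }
            }
         });
      }

      es.shutdown();
      try
      {
         if (!es.awaitTermination(2, TimeUnit.DAYS))
            throw new IllegalStateException("multiplication timed out");
      }
      catch (InterruptedException e)
      {
         Thread.currentThread().interrupt(); // preserve the interrupt status
         throw new IllegalStateException(e);
      }
      return result;
   }

   public static void main(String[] args)
   {
      int[][] a = {{1, 2}, {3, 4}};
      int[][] b = {{5, 6}, {7, 8}};
      int cores = Runtime.getRuntime().availableProcessors();
      System.out.println(Arrays.deepToString(multiply(a, b, cores))); // [[19, 22], [43, 50]]
   }
}
```

With this layout the number of tasks stays close to the core count no matter how large n gets, so the per-task overhead is spread over n*n/workers cells rather than paid once per cell. Whether it actually beats the classical loop still depends on n and on your machine, so benchmark it with a large matrix before drawing conclusions.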

Outlaw Programmer answered Nov 14 '22 10:11