Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Strassen algorithm not the fastest?

I copied Strassen's algorithm from somewhere and ran it. Here is the output:

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

where strassen1 is the version that allocates sub-matrices dynamically, strassen2 is the version that reuses cached buffers, and classical is the ordinary triple-loop matrix multiplication. This suggests that the old, simple classical algorithm is the fastest. Is this true, or am I wrong somewhere? Here's the code in Java:

import java.util.Random;

/**
 * Benchmark driver comparing the classical, dynamically-allocating Strassen ({@code strassen1})
 * and cached-buffer Strassen ({@code strassen2}) integer matrix multiplications.
 *
 * <p>Usage: {@code java TestIntMatrixMultiplication [n [seed]]}; both default to 256.
 * When {@code n < 64} the inputs and the classical / strassen2 products are printed instead
 * of being timed.
 */
class TestIntMatrixMultiplication {

    public static void main (String...args) throws Exception {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        final Random random = new Random(seed);

        int[][] a, b, c;

        a = new int[n][n];
        b = new int[n][n];
        c = new int[n][n];

        // Entries are drawn from [0, 100) so the products stay comfortably inside int range
        // for moderate n.
        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            }
        }

        System.out.println("n = " + n);

        // Small inputs: dump everything so the two products can be compared by eye.
        if (a.length < 64) {
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        }

        // Three rounds; later rounds presumably reflect JIT-compiled code (warm-up).
        // strassen1 is only run for n <= 256, presumably because it is very slow.
        for (int i = 0; i <3; ++i) {
            timeMultiplies1(a, b, c);
            if (n <= 256)
                timeMultiplies2( a, b, c);
            timeMultiplies3( a, b, c);
        }
    }

    /** Times one classical multiplication of {@code a * b} into {@code c} and prints it. */
    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        Classical.mult(c, a, b);
        System.out.println("classical took " + elapsedMillisSince(start) + "ms");
    }

    /** Times one strassen1 multiplication of {@code a * b} into {@code c} and prints it. */
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen1.mult(c, a, b);
        System.out.println("strassen 1 took " + elapsedMillisSince(start) + "ms");
    }

    /** Times one strassen2 multiplication of {@code a * b} into {@code c} and prints it. */
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen2.mult(c, a, b);
        System.out.println("strassen2 took " + elapsedMillisSince(start) + "ms");
    }

    /**
     * Returns the whole milliseconds elapsed since {@code startNanos}.
     *
     * <p>Uses {@link System#nanoTime()}, which is intended for measuring elapsed time, rather
     * than the wall-clock {@link System#currentTimeMillis()}, which can jump when the system
     * clock is adjusted and is often coarse-grained.
     *
     * @param startNanos a value previously returned by {@link System#nanoTime()}
     */
    private static long elapsedMillisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1000000L;
    }

    /** Prints {@code m} row by row, each row bracketed and tab-separated. */
    static void dumpMatrix (int[][] m) {
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}

class strassen1{

    public String getName () {
        return "Strassen(dynamic)";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        return strassenMatrixMultiplication(a, b);
    }

    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        if(n == 1) {
            result[0][0] = A[0][0] * B[0][0];
        } else {
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
            int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
            int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
            int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
            int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
            int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
            int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

            int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
            int [][] C12 = addMatrices(P3, P5);
            int [][] C21 = addMatrices(P2, P4);
            int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

            copySubArray(C11, result, 0 , 0);
            copySubArray(C12, result, 0 , n/2);
            copySubArray(C21, result, n/2, 0);
            copySubArray(C22, result, n/2, n/2);
        }

        return result;
    }

    public static int [][] addMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
        for(int j=0; j<n; j++)
        result[i][j] = A[i][j] + B[i][j];

        return result;
    }

    public static int [][] subtractMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                result[i][j] = A[i][j] - B[i][j];

        return result;
    }

    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                child[i1][j1] = parent[i2][j2];
    }

    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                parent[i2][j2] = child[i1][j1];
    }
}
/**
 * Strassen multiplication that reuses statically cached scratch buffers instead of
 * allocating sub-matrices on every recursive call.
 *
 * <p>Buffer layout: a recursion level handling an {@code n x n} sub-problem at column offset
 * {@code offs} writes its operands ({@code t0}, {@code t1}) and products ({@code p1..p7})
 * into rows {@code [0, n/2)} and columns {@code [offs, offs + n/2)}. It then passes
 * {@code offs + n/2} to its children, so nested levels never overwrite each other. For a
 * top-level {@code n x n} product each buffer therefore needs {@code n/2} rows and
 * {@code n - 1} columns.
 *
 * <p>Only square matrices whose size is a power of two are supported. Not thread-safe: the
 * scratch buffers are shared static state.
 */
class strassen2{

    public String getName () {
        return "Strassen(cached)";
    }

    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    static int [][] t0;
    static int [][] t1;

    /**
     * Computes {@code a * b} into {@code c} and returns {@code c}.
     *
     * @param c destination, {@code n x n} with {@code n} a power of two (or 0)
     * @param a left operand, at least {@code n x n}
     * @param b right operand, at least {@code n x n}
     * @return {@code c}
     * @throws IllegalArgumentException if {@code c.length} is neither 0 nor a power of two
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int n = c.length;

        if (n == 0) {
            return c;
        }
        // The recursion halves with n/2; any other size would silently drop rows/columns.
        if ((n & (n - 1)) != 0) {
            throw new IllegalArgumentException("matrix size must be a power of two, got " + n);
        }

        final int rows = n / 2;
        final int cols = n - 1;

        // Reallocate only when the cached buffers are too small, so repeated calls of the
        // same (or smaller) size reuse them.
        if (p1 == null || p1.length < rows || (rows > 0 && p1[0].length < cols)) {
            p1 = new int[rows][cols];
            p2 = new int[rows][cols];
            p3 = new int[rows][cols];
            p4 = new int[rows][cols];
            p5 = new int[rows][cols];
            p6 = new int[rows][cols];
            p7 = new int[rows][cols];
            t0 = new int[rows][cols];
            t1 = new int[rows][cols];
        }

        mult(c, a, b, 0, 0, n, 0);

        return c;
    }

    /**
     * Recursive step: multiplies the {@code n x n} blocks of {@code a} and {@code b} whose
     * top-left corner is {@code (i0, j0)} and writes the product into the same block of
     * {@code c}.
     *
     * @param offs first scratch-buffer column this level may use; children receive
     *             {@code offs + n/2} so they do not overwrite this level's operands/products
     */
    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
        if(n == 1) {
            c[i0][j0] = a[i0][j0] * b[i0][j0];
        } else {
            final int nBy2 = n/2;

            final int i1 = i0 + nBy2;
            final int j1 = j0 + nBy2;

            // offset applied to 'p' j index so recursive calls don't overwrite data
            final int jp0 = offs;

            // P1 <- (A11 + A22)(B11 + B22)
            //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P2 <- (A21 + A22)B11
            //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0];
                }
            }

            mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P3 <- A11(B12 - B22)
            //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
                }
            }

            mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P4 <- A22(B21 - B11)
            //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
                }
            }

            mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P5 <- (A11 + A12) B22
            //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j1];
                }
            }

            mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P6 <- (A21 - A11)(B11 + B12)
            //  T0 <- (A21 - A11), T1 <- (B11 + B12), P6 <- T0 * T1
            // (Strassen's M6 uses the SUM B11 + B12; a difference here corrupts C22.)
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i0][j + j1];
                }
            }

            mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P7 <- (A12 - A22)(B21 + B22)
            //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // combine
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    // C11 = P1 + P4 - P5 + P7;
                    c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                    // C12 = P3 + P5;
                    c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                    // C21 = P2 + P4;
                    c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                    // C22 = P1 + P3 - P2 + P6;
                    c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
                }
            }
        }
    }

    /** Debug helper: prints every scratch buffer. */
    void dumpInternal () {
        System.out.println("P1");
        TestIntMatrixMultiplication.dumpMatrix(p1);
        System.out.println("P2");
        TestIntMatrixMultiplication.dumpMatrix(p2);
        System.out.println("P3");
        TestIntMatrixMultiplication.dumpMatrix(p3);
        System.out.println("P4");
        TestIntMatrixMultiplication.dumpMatrix(p4);
        System.out.println("P5");
        TestIntMatrixMultiplication.dumpMatrix(p5);
        System.out.println("P6");
        TestIntMatrixMultiplication.dumpMatrix(p6);
        System.out.println("P7");
        TestIntMatrixMultiplication.dumpMatrix(p7);
        System.out.println("T0");
        TestIntMatrixMultiplication.dumpMatrix(t0);
        System.out.println("T1");
        TestIntMatrixMultiplication.dumpMatrix(t1);
    }
}


class Classical{
    public String getName () {
        return "classic";
    }

    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        int n = a.length;

        for(int i=0; i<n; i++) {
            final int[] a_i = a[i];
            final int[] c_i = c[i];

            for(int j=0; j<n; j++) {
                int sum = 0;

                for(int k=0; k<n; k++) {
                    sum += a_i[k] * b[k][j];
                }

                c_i[j] = sum;
            }
        }

        return c;
    }
}
like image 631
Lokesh Khandelwal Avatar asked Feb 23 '23 23:02

Lokesh Khandelwal


1 Answers

Issues I see:

1) Your Strassen multiply allocates memory dynamically at every recursion level. This is going to kill performance.

2) Your Strassen multiply should switch over to conventional multiplication for small sizes rather than recursing all the way down to 1x1 (though this optimization somewhat invalidates your test).

3) Your matrix size may simply be too small to see the difference.

You should do comparisons with several different sizes, perhaps 256, 512, 1024, 2048, 4096, 8192, ... Then plot the times and look at the trends. Since the sizes are all powers of 2, you will probably want the matrix-size axis on a log scale.

Strassen is only faster for large N. How large will depend a lot on the implementation. What you have done for classical is only a basic implementation and is not optimal on a modern machine either.

like image 178
phkahler Avatar answered Mar 04 '23 07:03

phkahler