 

matlab/octave - Generalized matrix multiplication

I would like to write a function that generalizes matrix multiplication. Basically, it should be able to do standard matrix multiplication, but it should allow changing the two binary operators (product/sum) to any other functions.

The goal is to be as efficient as possible, both in terms of CPU and memory. Of course, it will always be less efficient than A*B, but the flexibility of the operators is the point here.

Here are a few commands I could come up with after reading various interesting threads:

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = squeeze(C(1,:,:)); % sum and turn into mxd

The problem with methods 1-3 is that they generate n matrices before collapsing them with sum(). Method 4 is better because it performs the sum() inside the bsxfun, but bsxfun still generates n matrices (except that they are mostly empty, containing only a vector of the non-zero sums; the rest is filled with zeros to match the dimension requirements).

What I would like is something like the 4th method, but without the useless zeros, to spare memory.

Any idea?

asked Jun 16 '14 by gaborous


3 Answers

Here is a slightly more polished version of the solution you posted, with some small improvements.

We check whether we have more rows than columns (or the other way around), and then perform the multiplication accordingly, choosing either to multiply rows with matrices or matrices with columns (thus doing the fewest loop iterations).


Note: This may not always be the best strategy (going by rows instead of by columns), even if there are fewer rows than columns. The fact that MATLAB arrays are stored column-major in memory makes it more efficient to slice by columns, as the elements are stored consecutively, whereas accessing rows involves traversing elements by strides, which is not cache-friendly (think spatial locality).

Other than that, the code should handle double/single and real/complex inputs, as well as full/sparse matrices (and raise an error where the combination is not possible). It also handles empty matrices and zero dimensions.

function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end

Comparison

The function is no doubt slower throughout, but for larger sizes it is orders of magnitude worse than the built-in matrix multiplication:

        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51

(figure: log-scale timing plot of mtimes vs my_mtimes)

Here is the test code:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight

Extra

For completeness, below are two more naive ways to implement the generalized matrix multiplication (if you want to compare the performance, replace the last part of the my_mtimes function with either of these). I'm not even gonna bother posting their elapsed times :)

C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end

And another way (with a triple-loop):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end

What to try next?

If you want to squeeze out more performance, you're gonna have to move to a C/C++ MEX-file to cut down on the overhead of interpreted MATLAB code. You can still take advantage of optimized BLAS/LAPACK routines by calling them from MEX-files (see the second part of this post for an example). MATLAB ships with the Intel MKL library, which frankly you cannot beat when it comes to linear algebra computations on Intel processors.

Others have already mentioned a couple of submissions on the File Exchange that implement general-purpose matrix routines as MEX-files (see @natan's answer). Those are especially effective if you link them against an optimized BLAS library.

answered by Amro


Why not just exploit bsxfun's ability to accept an arbitrary function?

C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);

Here

  • f is the outer function (corresponding to sum in the matrix-multiplication case). It should accept a 3D array of arbitrary size nxmxp and operate along its columns (the first dimension) to return a 1xmxp array.
  • g is the inner function (corresponding to product in the matrix-multiplication case). As per bsxfun, it should accept as input either two column vectors of the same size, or one column vector and one scalar, and return as output a column vector of the same size as the input(s).

This works in MATLAB. I haven't tested it in Octave.


Example 1: Matrix-multiplication:

>> f = @sum;   %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
    28    30
    64    69

Check:

>> A*B
ans =
    28    30
    64    69

Example 2: Consider the above two matrices with

>> f = @(x,y) sum(abs(x));     %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
   14.8333   16.1538
    5.2500    5.6346

Check: manually compute C(1,2):

>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
   16.1538
answered by Luis Mendo


Without diving into the details, there are tools such as mtimesx and MMX that are fast general-purpose matrix and scalar operation routines. You can look into their code and adapt them to your needs. They would most likely be faster than MATLAB's bsxfun.

answered by bla