 

MATLAB: vectorize 4D matrix sum

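I need to perform the following calculation in MATLAB:

$$w_p = \sum_{l=1}^{N} \sum_{m=1}^{N} \sum_{n=1}^{N} A_{p l m n} \, v_l \, v_m \, \overline{v_n}$$

where w and v are vectors with N elements and A is a four-dimensional array (N^4 elements). This can be achieved with the following naive loop code: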

N=10;
A=rand(N,N,N,N);
v=rand(N,1);
w=zeros(N,1);

for pp=1:N
  for ll=1:N
    for mm=1:N
      for nn=1:N
        w(pp)=w(pp)+A(pp,ll,mm,nn)*v(ll)*v(mm)*conj(v(nn));
      end
    end
  end
end

which is extremely slow. Is there any way to vectorize this kind of sum in MATLAB?

Juan Sebastian Totero asked Dec 24 '22 22:12



2 Answers

Approach #1

With a few reshapes and matrix multiplications -

A1 = reshape(A,N^3,N)*conj(v);   % sum over nn (4th dim)
A2 = reshape(A1,N^2,N)*v;        % sum over mm (3rd dim)
w = reshape(A2,N,N)*v;           % sum over ll (2nd dim)
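A quick way to convince yourself that the reshape chain reproduces the loop is to compare both on a small random case. The following is only a sketch: the size N = 6, the complex test data, and the names w_loop, w_vec and rel_err are choices made here for illustration, not part of the original answer.

```matlab
N = 6;                                   % small size so the loop stays fast
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);    % complex data exercises the conj()
v = rand(N,1) + 1i*rand(N,1);

% Reference: the quadruple loop from the question
w_loop = zeros(N,1);
for pp = 1:N
  for ll = 1:N
    for mm = 1:N
      for nn = 1:N
        w_loop(pp) = w_loop(pp) + A(pp,ll,mm,nn)*v(ll)*v(mm)*conj(v(nn));
      end
    end
  end
end

% Reshape chain: each matrix product contracts the current last dimension
A1 = reshape(A, N^3, N) * conj(v);       % sums over nn
A2 = reshape(A1, N^2, N) * v;            % sums over mm
w_vec = reshape(A2, N, N) * v;           % sums over ll

rel_err = norm(w_vec - w_loop) / norm(w_loop);
fprintf('relative error: %g\n', rel_err);
```

The two results should agree up to floating-point rounding, because each reshape only regroups the column-major storage so that the dimension being summed becomes the column index.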

Approach #2

With one bsxfun, a reshape, and matrix multiplication -

A1 = reshape(A,N^3,N)*conj(v);   % sum over nn (4th dim)
vm = bsxfun(@times,v,v.');       % N-by-N outer product v(ll)*v(mm)
w = reshape(A1,N,N^2)*vm(:);     % sum over ll and mm in one product
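The outer product here can also be written without bsxfun. As a sketch (the names vv, w_outer and w_kron and the test size are assumptions made for illustration), v*v.' gives the same N-by-N matrix, and kron(v,v) gives its column-stacked form directly:

```matlab
N = 5;
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);
v = rand(N,1) + 1i*rand(N,1);

A1 = reshape(A, N^3, N) * conj(v);       % N^3-by-1, ordered as (pp,ll,mm)

vm = v * v.';                            % outer product; .' does not conjugate
vv = kron(v, v);                         % stacked products v(ll)*v(mm), N^2-by-1

w_outer = reshape(A1, N, N^2) * vm(:);   % as in the answer, minus bsxfun
w_kron  = reshape(A1, N, N^2) * vv;      % same contraction via kron
```

Note the non-conjugating transpose .' in the outer product: with complex v, using ' instead would conjugate one factor and change the result.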

Benchmarking

This section compares runtimes for the two approaches listed in this post, the first (bsxfun-based) approach from Shai's answer, and the original loop code from the question.

Benchmarking Code

N=100;
A=rand(N,N,N,N);
v=rand(N,1);

disp('----------------------------------- With Original Approach')
tic
%// .... Code from the original post   ...//
toc

disp('----------------------------------- With Shai Approach #1')
tic
s4 = sum( bsxfun( @times, A, permute( conj(v), [4 3 2 1] ) ), 4 ); 
s3 = sum( bsxfun( @times, s4, permute( v, [3 2 1] ) ), 3 );
w2 = s3*v; 
toc

disp('----------------------------------- With Divakar Approach #1')
tic
A1 = reshape(A,N^3,N)*conj(v);
A2 = reshape(A1,N^2,N)*v;
w3 = reshape(A2,N,N)*v;
toc

disp('----------------------------------- With Divakar Approach #2')
tic
A1 = reshape(A,N^3,N)*conj(v);
vm = bsxfun(@times,v,v.');
w4 = reshape(A1,N,N^2)*vm(:);
toc
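Single tic/toc measurements can be noisy. As an alternative sketch, assuming a MATLAB release that provides timeit, each approach can be wrapped in a function handle and timed over repeated calls. The reduced size N = 30 and the names f_reshape, f_outer, t_reshape and t_outer are illustrative choices, not taken from the original benchmark.

```matlab
N = 30;                  % smaller than the N = 100 above to keep memory modest
A = rand(N,N,N,N);
v = rand(N,1);

% Zero-argument handles, as timeit requires
f_reshape = @() reshape(reshape(reshape(A,N^3,N)*conj(v),N^2,N)*v,N,N)*v;
f_outer   = @() reshape(reshape(A,N^3,N)*conj(v),N,N^2)*reshape(v*v.',[],1);

t_reshape = timeit(f_reshape);   % typical time per call, in seconds
t_outer   = timeit(f_outer);

fprintf('reshape chain : %.6f s\n', t_reshape);
fprintf('outer product : %.6f s\n', t_outer);
```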

Runtime Results

----------------------------------- With Original Approach
Elapsed time is 4.604767 seconds.
----------------------------------- With Shai Approach #1
Elapsed time is 0.334667 seconds.
----------------------------------- With Divakar Approach #1
Elapsed time is 0.071905 seconds.
----------------------------------- With Divakar Approach #2
Elapsed time is 0.058877 seconds.

Conclusions

For N = 100, the second approach in this post gives roughly an 80x speedup over the original loop (about 0.059 s versus 4.6 s), and the first approach is close behind at about 0.072 s.

Divakar answered Dec 29 '22 00:12



You can try using bsxfun.

This assumes v is an N-by-1 column vector (otherwise, the permutations need to be adjusted slightly).

% sum over n (4th dim)
s4 = sum( bsxfun( @times, A, permute( conj(v), [4 3 2 1] ) ), 4 ); 

Now the interim result is only N-by-N-by-N.

% sum over m (3rd dim)
s3 = sum( bsxfun( @times, s4, permute( v, [3 2 1] ) ), 3 );

Continuing to the last sum

% sum over l (2nd dim)
w = s3*v; 
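In MATLAB R2016b and later, element-wise operators expand singleton dimensions implicitly, so the same steps can be written without bsxfun. The following is a sketch under that version assumption; the names w_ie and w_bsx and the test data are illustrative, and reshape is used in place of permute only to make the target shape explicit.

```matlab
N = 5;
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);
v = rand(N,1) + 1i*rand(N,1);

% Implicit expansion: .* broadcasts the 1x1x1xN and 1x1xN vectors
s4 = sum(A  .* reshape(conj(v), 1, 1, 1, N), 4);   % sum over n (4th dim)
s3 = sum(s4 .* reshape(v, 1, 1, N), 3);            % sum over m (3rd dim)
w_ie = s3 * v;                                     % sum over l (2nd dim)

% bsxfun version from the answer, for comparison
s4b = sum(bsxfun(@times, A, permute(conj(v), [4 3 2 1])), 4);
s3b = sum(bsxfun(@times, s4b, permute(v, [3 2 1])), 3);
w_bsx = s3b * v;
```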

Come to think of it, have you considered using dot in its multidimensional form? I did not test it, and it will probably need corrections, since dot expects both inputs to have the same size and conjugates its first argument.

s4 = dot( A, permute( conj(v), [4 3 2 1] ), 4 );
s3 = dot( s4, permute( v, [3 2 1] ), 3 );
w = s3*v;
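One way the dot idea could be made to run, as a sketch: dot(X,Y,dim) computes sum(conj(X).*Y,dim) on equally sized inputs, so v is expanded explicitly with repmat and passed as the first argument, which puts the conjugation on v rather than on A. The names V4, V3, w_dot and w_ref are illustrative, and the repmat copies cost extra memory, so this is unlikely to beat the reshape-based approaches.

```matlab
N = 5;
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);
v = rand(N,1) + 1i*rand(N,1);

% Expand v along dim 4; dot conjugates it, giving sum_n A(:,:,:,n)*conj(v(n))
V4 = repmat(permute(v, [4 3 2 1]), [N N N 1]);     % N x N x N x N
s4 = dot(V4, A, 4);

% Here we want plain v(m), so pre-conjugate to cancel dot's conjugation
V3 = repmat(permute(conj(v), [3 2 1]), [N N 1]);   % N x N x N
s3 = dot(V3, s4, 3);

w_dot = s3 * v;                                    % sum over l (2nd dim)

% Reference result from the reshape chain
w_ref = reshape(reshape(reshape(A,N^3,N)*conj(v),N^2,N)*v,N,N)*v;
```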

Shai answered Dec 29 '22 00:12
