I need to perform the following calculation in MATLAB:
$$w_p = \sum_{l=1}^{N}\sum_{m=1}^{N}\sum_{n=1}^{N} A_{plmn}\, v_l\, v_m\, \overline{v_n}$$
where w and v are vectors with N elements and A is a four-dimensional array (N^4 elements). This can be achieved by the following pedantic code:
N=10;
A=rand(N,N,N,N);
v=rand(N,1);
w=zeros(N,1);
for pp=1:N
    for ll=1:N
        for mm=1:N
            for nn=1:N
                w(pp)=w(pp)+A(pp,ll,mm,nn)*v(ll)*v(mm)*conj(v(nn));
            end
        end
    end
end
which is extremely slow. Is there any way to vectorize this kind of sum in MATLAB?
Approach #1
With a few reshapes and matrix multiplications -
A1 = reshape(A,N^3,N)*conj(v)
A2 = reshape(A1,N^2,N)*v
w = reshape(A2,N,N)*v
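The reshapes work because MATLAB stores arrays in column-major order: reshape(A,N^3,N) places the last index along the columns, so a single matrix-vector product sums over it. A minimal sanity check against the loop from the question might look like the sketch below; the small N, the complex test data and the names w_ref and err are assumptions made here for illustration.

```matlab
N = 4;                                   % small size so the loop stays fast
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);    % complex test data (assumption)
v = rand(N,1) + 1i*rand(N,1);

% Reference: the explicit quadruple loop from the question
w_ref = zeros(N,1);
for pp = 1:N
    for ll = 1:N
        for mm = 1:N
            for nn = 1:N
                w_ref(pp) = w_ref(pp) + A(pp,ll,mm,nn)*v(ll)*v(mm)*conj(v(nn));
            end
        end
    end
end

% Approach #1: contract one index per matrix-vector product
A1 = reshape(A,N^3,N)*conj(v);   % sum over nn
A2 = reshape(A1,N^2,N)*v;        % sum over mm
w  = reshape(A2,N,N)*v;          % sum over ll

err = max(abs(w - w_ref));       % expected to be at rounding-error level
```

Using a complex v in the check matters, since a real v would not reveal a misplaced conj.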
Approach #2
With one bsxfun, reshape and matrix multiplication -
A1 = reshape(A,N^3,N)*conj(v)
vm = bsxfun(@times,v,v.')
w = reshape(A1,N,N^2)*vm(:)
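Here vm(:) stacks the products v(ll)*v(mm) in the same column-major order as the columns of reshape(A1,N,N^2), which is why one product contracts both remaining indices. As a sketch, the same vector can also be built with a plain outer product or with kron; the variable names below are illustrative only. Note the non-conjugating transpose .' throughout, since v' would conjugate a complex v.

```matlab
N = 5;
v = rand(N,1) + 1i*rand(N,1);    % complex test vector (assumption)

vm_bsx   = bsxfun(@times,v,v.'); % as in Approach #2
vm_outer = v*v.';                % plain outer product, non-conjugating transpose
vm_kron  = kron(v,v);            % already a column of length N^2

% Largest deviations between the three constructions
d1 = max(abs(vm_bsx(:) - vm_outer(:)));
d2 = max(abs(vm_bsx(:) - vm_kron));
```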
This section compares runtimes for the two approaches listed in this post, the first approach from Shai's post, and the original approach listed in the question.
Benchmarking Code
N=100;
A=rand(N,N,N,N);
v=rand(N,1);
disp('----------------------------------- With Original Approach')
tic
%// .... Code from the original post ...//
toc
disp('----------------------------------- With Shai Approach #1')
tic
s4 = sum( bsxfun( @times, A, permute( conj(v), [4 3 2 1] ) ), 4 );
s3 = sum( bsxfun( @times, s4, permute( v, [3 2 1] ) ), 3 );
w2 = s3*v;
toc
disp('----------------------------------- With Divakar Approach #1')
tic
A1 = reshape(A,N^3,N)*conj(v);
A2 = reshape(A1,N^2,N)*v;
w3 = reshape(A2,N,N)*v;
toc
disp('----------------------------------- With Divakar Approach #2')
tic
A1 = reshape(A,N^3,N)*conj(v);
vm = bsxfun(@times,v,v.');
w4 = reshape(A1,N,N^2)*vm(:);
toc
Runtime Results
----------------------------------- With Original Approach
Elapsed time is 4.604767 seconds.
----------------------------------- With Shai Approach #1
Elapsed time is 0.334667 seconds.
----------------------------------- With Divakar Approach #1
Elapsed time is 0.071905 seconds.
----------------------------------- With Divakar Approach #2
Elapsed time is 0.058877 seconds.
Conclusions
The second approach in this post seems to give about an 80x speedup over the original approach (0.059 s versus 4.60 s).
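Single tic/toc measurements like the ones above can be noisy. As a hedged sketch, timeit (available in recent MATLAB releases) calls each candidate several times and returns a typical run time. The handle names f1 and f2 and the smaller N are assumptions made here, and absolute numbers will differ from machine to machine.

```matlab
N = 30;                          % smaller than the benchmark above, to keep memory modest
A = rand(N,N,N,N);
v = rand(N,1);

% Wrap each approach in a function handle that takes no inputs
f1 = @() reshape(reshape(reshape(A,N^3,N)*conj(v),N^2,N)*v,N,N)*v;
f2 = @() reshape(reshape(A,N^3,N)*conj(v),N,N^2)*reshape(bsxfun(@times,v,v.'),[],1);

t1 = timeit(f1);                 % typical run time of Approach #1, in seconds
t2 = timeit(f2);                 % typical run time of Approach #2, in seconds
fprintf('Approach #1: %.6f s, Approach #2: %.6f s\n', t1, t2);
```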
You can try using bsxfun.
Assuming v is an N-by-1 column vector (otherwise, the permutations should be modified a little bit):
% sum over n (4th dim)
s4 = sum( bsxfun( @times, A, permute( conj(v), [4 3 2 1] ) ), 4 );
Now the interim result is only N-by-N-by-N.
% sum over m (3rd dim)
s3 = sum( bsxfun( @times, s4, permute( v, [3 2 1] ) ), 3 );
Continuing to the last sum
% sum over l (2nd dim)
w = s3*v;
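On MATLAB R2016b and later, element-wise operators expand singleton dimensions implicitly, so the bsxfun calls can be written with plain .* instead. A minimal sketch, assuming such a release; the test data and the names t4, t3, w_bsx and w_imp are introduced here only for the comparison.

```matlab
N = 6;
A = rand(N,N,N,N) + 1i*rand(N,N,N,N);   % complex test data (assumption)
v = rand(N,1) + 1i*rand(N,1);

% bsxfun version, as in the steps above
s4 = sum( bsxfun( @times, A, permute( conj(v), [4 3 2 1] ) ), 4 );
s3 = sum( bsxfun( @times, s4, permute( v, [3 2 1] ) ), 3 );
w_bsx = s3*v;

% Same contractions with implicit expansion (requires R2016b or newer)
t4 = sum( A  .* permute( conj(v), [4 3 2 1] ), 4 );  % sum over n (4th dim)
t3 = sum( t4 .* permute( v, [3 2 1] ), 3 );          % sum over m (3rd dim)
w_imp = t3*v;                                        % sum over l (2nd dim)

d = max(abs(w_bsx - w_imp));     % expected to be at rounding-error level
```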
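Coming to think about it, have you considered using dot in its multidimensional form? I did not test it, and two details need care: dot requires both inputs to have the same size, so v has to be expanded with repmat first, and dot conjugates its first argument, so v goes first wherever conj(v) is wanted and conj(v) goes first wherever plain v is wanted.
Vn = repmat( permute( v, [4 3 2 1] ), [N N N 1] );
s4 = dot( Vn, A, 4 ); % sum over n of conj(v(n))*A(:,:,:,n)
Vm = repmat( permute( conj(v), [3 2 1] ), [N N 1] );
s3 = dot( Vm, s4, 3 ); % sum over m of v(m)*s4(:,:,m)
w = s3*v;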