I have a performance bottleneck in an application that requires zeroing out the off-diagonal elements of a large square matrix. For example, the matrix x
17 24 1 8 15
23 5 7 14 16
4 6 13 20 22
10 12 19 21 3
11 18 25 2 9
becomes
17 0 0 0 0
0 5 0 0 0
0 0 13 0 0
0 0 0 21 0
0 0 0 0 9
Question: the bsxfun-and-diag solution below is the fastest so far, and I doubt I can improve on it while keeping the code in MATLAB, but is there a faster way?
Here is what I thought of so far.
Perform element-wise multiplication by the identity matrix. This is the simplest solution:
y = x .* eye(n);
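A variation on this approach, not included in the timings below, is to use a sparse identity so the full n-by-n eye(n) is never materialized. Note this is a sketch I have not benchmarked, and the result comes back as a sparse matrix:

```matlab
% Sketch (untimed): multiplying by a sparse identity avoids allocating
% a dense eye(n). The result y is sparse; wrap it in full(y) if later
% code expects a dense matrix.
y = x .* speye(n);
```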
Using bsxfun and diag:
y = bsxfun(@times, diag(x), eye(n));
Lower/upper triangular matrices:
y = x - tril(x, -1) - triu(x, 1);
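Another loop-free variant, not in my timings, zeroes the off-diagonal entries through a logical mask. It still allocates a full n-by-n mask, so I'd expect it to perform in the same ballpark as the eye(n) multiplication:

```matlab
% Sketch: zero the off-diagonal entries via a logical mask.
% ~eye(n) is true everywhere except on the main diagonal.
y = x;
y(~eye(n)) = 0;
```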
Various solutions using loops:
y = x;
for ix = 1:n
    for jx = 1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end
and
y = x;
for ix = 1:n
    for jx = 1:ix-1
        y(ix, jx) = 0;
    end
    for jx = ix+1:n
        y(ix, jx) = 0;
    end
end
The bsxfun solution is actually the fastest. This is my timing code:
function timing()
    n = 5000;
    x = rand(n, n);
    f1 = @() tf1(x, n);
    f2 = @() tf2(x, n);
    f3 = @() tf3(x);
    f4 = @() tf4(x, n);
    f5 = @() tf5(x, n);
    t1 = timeit(f1);
    t2 = timeit(f2);
    t3 = timeit(f3);
    t4 = timeit(f4);
    t5 = timeit(f5);
    fprintf('t1: %f s\n', t1)
    fprintf('t2: %f s\n', t2)
    fprintf('t3: %f s\n', t3)
    fprintf('t4: %f s\n', t4)
    fprintf('t5: %f s\n', t5)
end
function y = tf1(x, n)
    y = x .* eye(n);
end
function y = tf2(x, n)
    y = bsxfun(@times, diag(x), eye(n));
end
function y = tf3(x)
    y = x - tril(x, -1) - triu(x, 1);
end
function y = tf4(x, n)
    y = x;
    for ix = 1:n
        for jx = 1:n
            if ix ~= jx
                y(ix, jx) = 0;
            end
        end
    end
end
function y = tf5(x, n)
    y = x;
    for ix = 1:n
        for jx = 1:ix-1
            y(ix, jx) = 0;
        end
        for jx = ix+1:n
            y(ix, jx) = 0;
        end
    end
end
which returns
t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
I found that:
diag(diag(x))
is faster than bsxfun. Similarly:
diag(x(1:size(x,1)+1:end))
is faster by more or less the same amount. Playing with timeit for x = rand(5000), I got both faster than your bsxfun solution by a factor of ~20.
EDIT:
This is on par with diag(diag(...)):
x2(n,n)=0;
x2(1:n+1:end)=x(1:n+1:end);
Note that the way I preallocate x2 is important: if you just use x2 = zeros(n), you'll get a slower solution. Read more about this in this discussion...
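Both of these fast solutions rely on MATLAB's column-major linear indexing: in an n-by-n matrix, the main diagonal sits at linear indices 1, n+2, 2n+3, ..., i.e. the stride-(n+1) range 1:n+1:end. A small sketch to convince yourself:

```matlab
% The main diagonal of an n-by-n matrix occupies linear indices
% 1:n+1:end in column-major order.
n = 5;
x = magic(n);
d = x(1:n+1:end);        % row vector of diagonal elements
isequal(d.', diag(x))    % true
```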
I didn't bother testing your various loop functions, since they were much slower in your implementation, but I tested the others, plus another method that I've used before:
y = diag(diag(x));
Here's the spoiler:
c1: 193.18 milliseconds // multiply by identity
c2: 102.16 milliseconds // bsxfun
c3: 342.24 milliseconds // tril and triu
c4: 6.03 milliseconds // call diag twice
It looks like making two calls to diag is by far the fastest on my machine.
Full timing code follows. I used my own benchmarking function rather than timeit, but the results should be comparable (and you can check them yourself).
>> x = randn(5000);
>> c1 = @() x .* eye(5000);
>> c2 = @() bsxfun(@times, diag(x), eye(5000));
>> c3 = @() x - tril(x,-1) - triu(x,1);
>> c4 = @() diag(diag(x));
>> benchmark.bench(c1)
Benchmarking @()x.*eye(5000)
Mean: 193.18 milliseconds, lb 191.94 milliseconds, ub 194.25 milliseconds, ci 95%
Stdev: 6.01 milliseconds, lb 3.27 milliseconds, ub 8.58 milliseconds, ci 95%
>> benchmark.bench(c2)
Benchmarking @()bsxfun(@times,diag(x),eye(5000))
Mean: 102.16 milliseconds, lb 100.83 milliseconds, ub 103.44 milliseconds, ci 95%
Stdev: 6.61 milliseconds, lb 6.04 milliseconds, ub 7.07 milliseconds, ci 95%
>> benchmark.bench(c3)
Benchmarking @()x-tril(x,-1)-triu(x,1)
Mean: 342.24 milliseconds, lb 340.28 milliseconds, ub 344.20 milliseconds, ci 95%
Stdev: 10.06 milliseconds, lb 8.85 milliseconds, ub 11.17 milliseconds, ci 95%
>> benchmark.bench(c4)
Benchmarking @()diag(diag(x))
Mean: 6.03 milliseconds, lb 5.96 milliseconds, ub 6.09 milliseconds, ci 95%
Stdev: 0.34 milliseconds, lb 0.27 milliseconds, ub 0.40 milliseconds, ci 95%