
Is there a fast(er) way to remove all non-diagonal elements from a square matrix?

Tags: matrix, matlab

I have a small performance bottleneck in an application that requires removing the non-diagonal elements from a large square matrix. So, the matrix x

17    24     1     8    15
23     5     7    14    16
 4     6    13    20    22
10    12    19    21     3
11    18    25     2     9

becomes

17     0     0     0     0
 0     5     0     0     0
 0     0    13     0     0
 0     0     0    21     0
 0     0     0     0     9
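
For reference, the example matrix above is magic(5), so the transformation can be reproduced with a couple of lines (a minimal sketch using the simplest approach listed below):

x = magic(5);            % the 5-by-5 example matrix shown above
y = x .* eye(size(x));   % zero out everything except the main diagonal
disp(y)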

Question: The bsxfun and diag solution below is the fastest I have found so far, and I doubt I can improve on it while keeping the code in MATLAB, but is there a faster way?

Solutions

Here is what I have come up with so far.

Perform element-wise multiplication by the identity matrix. This is the simplest solution:

y = x .* eye(n);

Using bsxfun and diag:

y = bsxfun(@times, diag(x), eye(n));

Lower/upper triangular matrices:

y = x - tril(x, -1) - triu(x, 1);

Various solutions using loops:

y = x;
for ix=1:n
    for jx=1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end

and

y = x;
for ix=1:n
    for jx=1:ix-1
        y(ix, jx) = 0;
    end
    for jx=ix+1:n
        y(ix, jx) = 0;
    end
end

Timing

The bsxfun solution is actually the fastest. This is my timing code:

function timing()
clear all

n = 5000;
x = rand(n, n);

f1 = @() tf1(x, n);
f2 = @() tf2(x, n);
f3 = @() tf3(x);
f4 = @() tf4(x, n);
f5 = @() tf5(x, n);

t1 = timeit(f1);
t2 = timeit(f2);
t3 = timeit(f3);
t4 = timeit(f4);
t5 = timeit(f5);

fprintf('t1: %f s\n', t1)
fprintf('t2: %f s\n', t2)
fprintf('t3: %f s\n', t3)
fprintf('t4: %f s\n', t4)
fprintf('t5: %f s\n', t5)
end

function y = tf1(x, n)
y = x .* eye(n);
end


function y = tf2(x, n)
y = bsxfun(@times, diag(x), eye(n));
end


function y = tf3(x)
y = x - tril(x, -1) - triu(x, 1);
end


function y = tf4(x, n)
y = x;
for ix=1:n
    for jx=1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end
end


function y = tf5(x, n)
y = x;
for ix=1:n
    for jx=1:ix-1
        y(ix, jx) = 0;
    end
    for jx=ix+1:n
        y(ix, jx) = 0;
    end
end
end

which returns

t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
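
As a quick sanity check (not part of the original timing), one could also verify at the end of timing() that all five implementations return the same matrix:

y1 = tf1(x, n); y2 = tf2(x, n); y3 = tf3(x);
y4 = tf4(x, n); y5 = tf5(x, n);
assert(isequal(y1, y2, y3, y4, y5))   % all methods should agree exactly
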
Asked Oct 04 '13 by Michael A



2 Answers

I found that:

diag(diag(x))

is faster than bsxfun. Similarly:

diag(x(1:size(x,1)+1:end))

is faster by about the same amount. Playing with timeit for x = rand(5000), I found both to be faster than your bsxfun solution by a factor of ~20.
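
A minimal timing sketch for those two variants, assuming the same n = 5000 setup as in the question:

n = 5000;
x = rand(n);
t_diag2  = timeit(@() diag(diag(x)));                 % diag of the diagonal vector
t_linidx = timeit(@() diag(x(1:size(x,1)+1:end)));    % linear-index the diagonal, then diag
fprintf('diag(diag(x)): %f s\nlinear index:  %f s\n', t_diag2, t_linidx)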

EDIT:

This is on par with diag(diag(...)):

x2(n,n)=0;
x2(1:n+1:end)=x(1:n+1:end);

Note that the way I preallocate x2 is important: if you just use x2 = zeros(n), you'll get a slower solution. Read more about this in this discussion...
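
A hedged sketch of how one might time the two preallocation styles side by side (the helper function names are just for illustration; save this as a script or function file so the local functions at the end are visible):

n = 5000;
x = rand(n);
fprintf('x2(n,n) = 0:  %f s\n', timeit(@() diag_prealloc_grow(x, n)))
fprintf('zeros(n):     %f s\n', timeit(@() diag_prealloc_zeros(x, n)))

function y = diag_prealloc_grow(x, n)
% Preallocate by assigning to the last element, as in the answer above.
y(n, n) = 0;
y(1:n+1:end) = x(1:n+1:end);
end

function y = diag_prealloc_zeros(x, n)
% Preallocate with zeros(n) for comparison.
y = zeros(n);
y(1:n+1:end) = x(1:n+1:end);
end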

Answered by bla


I didn't bother testing your various loop functions, since they were much slower in your implementation, but I tested the others, plus another method that I've used before:

y = diag(diag(x));

Here's the spoiler:

c1: 193.18 milliseconds  // multiply by identity
c2: 102.16 milliseconds  // bsxfun
c3: 342.24 milliseconds  // tril and triu
c4:   6.03 milliseconds  // call diag twice

It looks like calling diag twice is by far the fastest approach on my machine.

Full timing code follows. I used my own benchmarking function rather than timeit, but the results should be comparable (and you can check them yourself).

>> x = randn(5000);

>> c1 = @() x .* eye(5000);
>> c2 = @() bsxfun(@times, diag(x), eye(5000));
>> c3 = @() x - tril(x,-1) - triu(x,1);
>> c4 = @() diag(diag(x));


>> benchmark.bench(c1)

Benchmarking @()x.*eye(5000)
   Mean: 193.18 milliseconds, lb 191.94 milliseconds, ub 194.25 milliseconds, ci 95%
  Stdev: 6.01 milliseconds, lb 3.27 milliseconds, ub 8.58 milliseconds, ci 95%

>> benchmark.bench(c2)

Benchmarking @()bsxfun(@times,diag(x),eye(5000))
   Mean: 102.16 milliseconds, lb 100.83 milliseconds, ub 103.44 milliseconds, ci 95%
  Stdev: 6.61 milliseconds, lb 6.04 milliseconds, ub 7.07 milliseconds, ci 95%

>> benchmark.bench(c3)

Benchmarking @()x-tril(x,-1)-triu(x,1)
   Mean: 342.24 milliseconds, lb 340.28 milliseconds, ub 344.20 milliseconds, ci 95%
  Stdev: 10.06 milliseconds, lb 8.85 milliseconds, ub 11.17 milliseconds, ci 95%

>> benchmark.bench(c4)

Benchmarking @()diag(diag(x))
   Mean: 6.03 milliseconds, lb 5.96 milliseconds, ub 6.09 milliseconds, ci 95%
  Stdev: 0.34 milliseconds, lb 0.27 milliseconds, ub 0.40 milliseconds, ci 95%
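
For anyone who prefers to cross-check with the built-in timeit instead of the benchmark.bench helper, a minimal sketch with the same four candidates:

x = randn(5000);

c1 = @() x .* eye(5000);
c2 = @() bsxfun(@times, diag(x), eye(5000));
c3 = @() x - tril(x,-1) - triu(x,1);
c4 = @() diag(diag(x));

fprintf('c1: %.2f ms\n', 1e3 * timeit(c1))   % multiply by identity
fprintf('c2: %.2f ms\n', 1e3 * timeit(c2))   % bsxfun
fprintf('c3: %.2f ms\n', 1e3 * timeit(c3))   % tril and triu
fprintf('c4: %.2f ms\n', 1e3 * timeit(c4))   % call diag twice
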
Answered by Chris Taylor