
Intersection indices by rows

Given these two matrices:

m1 = [ 1 1;
       2 2;
       3 3;
       4 4;
       5 5 ];

m2 = [ 4 2;
       1 1;
       4 4;
       7 5 ];

I'm looking for a function, such as:

indices = GetIntersectionIndecies (m1,m2);

whose output for these inputs would be:

indices = 
          1
          0
          0
          1
          0

How can I find the intersection indices, i.e. which rows of m1 also appear as rows of m2, without using a loop?

asked Jan 15 '23 10:01 by Sameh K. Mohamed



2 Answers

One possible solution:

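function [Index] = GetIntersectionIndicies(m1, m2)
% I1 holds indices of rows of m1 that also occur in m2
% (intersect returns one index per distinct common row)
[~, I1] = intersect(m1, m2, 'rows');
Index = zeros(size(m1, 1), 1);
Index(I1) = 1;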
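One caveat: intersect with 'rows' returns a single index for each distinct common row, so if m1 contains the same row more than once, only one of those copies gets flagged. A minimal sketch of an alternative that flags every matching row uses ismember with the 'rows' option; the test matrices here are made up for illustration, and IndexIntersect is just a hypothetical name for the intersect-based result:

```matlab
% Hypothetical example data: m1 contains the row [1 1] twice (rows 1 and 3)
m1 = [1 1; 2 2; 1 1; 4 4];
m2 = [4 2; 1 1; 4 4];

% intersect-based approach: one index per distinct common row,
% so only one of rows 1 and 3 is flagged
[~, I1] = intersect(m1, m2, 'rows');
IndexIntersect = zeros(size(m1, 1), 1);
IndexIntersect(I1) = 1;

% ismember tests each row of m1 on its own, so every copy is flagged
Index = ismember(m1, m2, 'rows')   % logical [1; 0; 1; 1]
```

ismember returns a logical column vector, which can be used directly for logical indexing, e.g. m1(Index, :).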

By the way, I love the inventive solution by @Shai, and it is much faster than mine when the matrices are small. But for large matrices my solution will dominate. The reason is memory: with T = size(m1, 1) = size(m2, 1), the tmp variable in @Shai's answer is a T-by-T-by-2 array (in general, size(m1,1)-by-size(m2,1)-by-size(m1,2)), which becomes very large when T is large. Here's some code for a quick speed test:

%# Set parameters
T = 1000;
M = 10;

%# Build test matrices
m1 = randi(5, T, 2);
m2 = randi(5, T, 2);

%# My solution
tic
for m = 1:M
[~, I1] = intersect(m1, m2, 'rows');
Index = zeros(size(m1, 1), 1);
Index(I1) = 1;
end
toc

%# @Shai solution
tic
for m = 1:M
tmp = bsxfun( @eq, permute( m1, [ 1 3 2 ] ), permute( m2, [ 3 1 2 ] ) );
tmp = all( tmp, 3 ); % tmp(i,j) is true iff m1(i,:) == m2(j,:)
indices = any( tmp, 2 );
end
toc

With T = 10 and M = 1000, we get:

Elapsed time is 0.404726 seconds. %# My solution
Elapsed time is 0.017669 seconds. %# @Shai solution

But with T = 1000 and M = 100, we get:

Elapsed time is 0.068831 seconds. %# My solution
Elapsed time is 0.508370 seconds. %# @Shai solution
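To check the memory argument directly, here is a small sketch using the same kind of random test data as the benchmark. bsxfun with @eq produces a logical array, which takes one byte per element, so with T rows in each matrix and 2 columns the intermediate array should occupy 2*T*T bytes:

```matlab
T = 1000;                      % rows in each random test matrix
m1 = randi(5, T, 2);
m2 = randi(5, T, 2);

% Pairwise comparison: tmp(i,j,k) is true iff m1(i,k) == m2(j,k)
tmp = bsxfun(@eq, permute(m1, [1 3 2]), permute(m2, [3 1 2]));
disp(size(tmp))                % T-by-T-by-2

% Memory used by the intermediate logical array
info = whos('tmp');
fprintf('tmp uses %d bytes\n', info.bytes);
```

Since the size grows with T squared, doubling T roughly quadruples this intermediate, which is consistent with the intersect-based approach pulling ahead on larger inputs.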
answered Jan 21 '23 16:01 by Colin T Bowers



How about using bsxfun?

function indices = GetIntersectionIndecies( m1, m2 )
    tmp = bsxfun( @eq, permute( m1, [ 1 3 2 ] ), permute( m2, [ 3 1 2 ] ) );
    tmp = all( tmp, 3 ); % tmp(i,j) is true iff m1(i,:) == m2(j,:)
    indices = any( tmp, 2 );
end

Cheers!
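On MATLAB R2016b or later, implicit expansion lets the same pairwise comparison be written without bsxfun. A sketch applied to the matrices from the question; it still builds a size(m1,1)-by-size(m2,1)-by-size(m1,2) intermediate, so the memory trade-off is unchanged:

```matlab
m1 = [1 1; 2 2; 3 3; 4 4; 5 5];
m2 = [4 2; 1 1; 4 4; 7 5];

% permute turns m1 into size(m1,1)-by-1-by-ncols and m2 into
% 1-by-size(m2,1)-by-ncols; == expands both to a common 3-D size
tmp = permute(m1, [1 3 2]) == permute(m2, [3 1 2]);

% A row of m1 matches a row of m2 only if every column matches
indices = any(all(tmp, 3), 2)   % logical [1; 0; 0; 1; 0]
```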

answered Jan 21 '23 16:01 by Shai
