
General method to find a submatrix in a MATLAB matrix

I am looking for a 'good' way to find a matrix (pattern) in a larger matrix (arbitrary number of dimensions).

Example:

total = rand(3,4,5);
sub = total(2:3,1:3,3:4);

Now I want this to happen:

loc = matrixFind(total, sub)

In this case loc should become [2 1 3].

For now I am just interested in finding one single point (if it exists) and am not worried about rounding issues. It can be assumed that sub 'fits' in total.


Here is how I could do it for 3 dimensions, however it just feels like there is a better way:

total = rand(3,4,5);
sub = total(2:3,1:3,3:4);
loc = [];
for x = 1:size(total,1)-size(sub,1)+1
    for y = 1:size(total,2)-size(sub,2)+1
        for z = 1:size(total,3)-size(sub,3)+1
            block = total(x:x+size(sub,1)-1,y:y+size(sub,2)-1,z:z+size(sub,3)-1);
            if isequal(sub,block)
                loc = [x y z]
            end
        end
    end
end

I hope to find a workable solution for an arbitrary number of dimensions.

Dennis Jaheruddin asked Sep 16 '13 14:09




2 Answers

Here is a low-performance but (supposedly) arbitrary-dimensional function. It uses find to build a list of (linear) indices of potential matching positions in total, and then checks whether the appropriately sized sub-block of total at each position equals sub.

function loc = matrixFind(total, sub)
%matrixFind find position of array in another array

    % initialize result
    loc = [];

    % pre-check: do all elements of sub exist in total?
    elements_in_both = intersect(sub(:), total(:));
    if numel(elements_in_both) < numel(unique(sub))
        % if not, return nothing
        return
    end

    % select a pivot element
    % Improvement: use the least common element in total for fewer iterations
    pivot_element = sub(1);

    % determine linear indices of all occurrences of pivot_element in total
    starting_positions = find(total == pivot_element);

    % prepare cell arrays for variable length subscript vectors
    [subscripts, subscript_ranges] = deal(cell([1, ndims(total)]));


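    for k = 1:length(starting_positions)
        % fill subscript vector for starting position
        [subscripts{:}] = ind2sub(size(total), starting_positions(k));

        % add offsets according to size of sub per dimension
        fits = true;
        for m = 1:length(subscripts)
            last_index = subscripts{m} + size(sub, m) - 1;
            fits = fits && last_index <= size(total, m);
            subscript_ranges{m} = subscripts{m}:last_index;
        end

        % skip positions where sub would extend past the edge of total
        % (indexing there would raise an out-of-bounds error)
        if ~fits
            continue
        end

        % is subblock of total equal to sub
        if isequal(total(subscript_ranges{:}), sub)
            loc = [loc; cell2mat(subscripts)]; %#ok<AGROW>
        end
    end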
end
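
The part of the function above that makes it dimension-agnostic is the use of cell arrays as comma-separated lists, both for the outputs of ind2sub and for indexing. A minimal sketch of that idiom in isolation follows; the test array, the linear index 26 and the name sub_size are illustrative assumptions, not part of the answer:

```matlab
% Illustrative data: every element equals its own linear index
total = reshape(1:60, [3 4 5]);
sub_size = [2 3 2];                    % hypothetical size of the pattern

% One output per dimension, collected in a cell array
subscripts = cell(1, ndims(total));
[subscripts{:}] = ind2sub(size(total), 26);
start = cell2mat(subscripts);          % subscripts of linear index 26

% One index range per dimension, then expand the cell array as subscripts
ranges = cell(1, ndims(total));
for m = 1:numel(ranges)
    ranges{m} = start(m):start(m) + sub_size(m) - 1;
end
block = total(ranges{:});              % same as total(2:3, 1:3, 3:4)
```

Because nothing here hard-codes the number of dimensions, the same lines should work unchanged for a 2-D or 5-D total, as long as sub_size has one entry per dimension.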
ojdo answered Oct 11 '22 23:10


This is based on doing all possible shifts of the original matrix total and comparing the upper-leftmost-etc sub-matrix of the shifted total with the sought pattern sub. Shifts are generated using strings, and are applied using circshift.

Most of the work is done vectorized. Only one level of loops is used.
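
To make the shift-and-compare idea concrete, here is a small 2-D sketch of a single shift. The matrix magic(4) and the offset [1 2] are assumptions chosen for illustration; the actual function loops over all admissible shifts and compares through a precomputed linear-index filter rather than explicit ranges:

```matlab
total = magic(4);
sub = total(2:3, 3:4);                 % pattern whose top-left corner is at (2,3)

shift = [1 2];                         % zero-based offset of sub inside total
total_shifted = circshift(total, -shift);   % move (2,3) to position (1,1)

% After the shift, the pattern sits in the leading corner of the array
corner = total_shifted(1:size(sub,1), 1:size(sub,2));
is_match = isequal(corner, sub);
location = shift + 1;                  % one-based position of the match
```

Since circshift wraps around, only shifts no larger than size(total) - size(sub) per dimension are meaningful, which is what the allowed_shift filter in the function enforces.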

The function finds all matches, not just the first. For example:

>> total = ones(3,4,5,6);
>> sub = ones(3,3,5,6);
>> matrixFind(total, sub)
ans =

     1     1     1     1
     1     2     1     1

Here is the function:

function sol = matrixFind(total, sub)

nd = ndims(total);
sizt = size(total).';
max_sizt = max(sizt);
sizs = [ size(sub) ones(1,nd-ndims(sub)) ].'; % in case there are
% trailing singletons

if any(sizs>sizt)
    error('Incorrect dimensions')
end

allowed_shift = (sizt-sizs);
max_allowed_shift = max(allowed_shift);
if max_allowed_shift>0
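    % each column of shifts is one combination of per-dimension offsets,
    % read off as the digits of a base-(max_allowed_shift+1) number.
    % Caveat: dec2base uses letters for digits above 9, so subtracting '0'
    % only gives correct values while max_allowed_shift <= 9.
    shifts = dec2base(0:(max_allowed_shift+1)^nd-1,max_allowed_shift+1).'-'0';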
    filter = all(bsxfun(@le,shifts,allowed_shift));
    shifts = shifts(:,filter); % possible shifts of matrix "total", along 
    % all dimensions
else
    shifts = zeros(nd,1);
end

for dim = 1:nd
    d{dim} = 1:sizt(dim); % vectors with subindices per dimension
end
g = cell(1,nd);
[g{:}] = ndgrid(d{:}); % grid of subindices per dimension
gc = cat(nd+1,g{:}); % concatenated grid
accept = repmat(permute(sizs,[2:nd+1 1]), [sizt; 1]); % acceptable values
% of subindices in order to compare with matrix "sub"
ind_filter = find(all(gc<=accept,nd+1));

sol = [];
for shift = shifts
    total_shifted = circshift(total,-shift);
    if all(total_shifted(ind_filter)==sub(:))
        sol = [ sol; shift.'+1 ];
    end
end
Luis Mendo answered Oct 11 '22 23:10