 

How do you find and save duplicated rows in a NumPy array?

Tags:

python

numpy

rows

I have an array, e.g.:

Array = [[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]]

I would like something that outputs the following:

Repeated = [[1,1,1],[2,2,2]]

Preserving the number of repeated rows would work too, e.g.

Repeated = [[1,1,1],[1,1,1],[2,2,2],[2,2,2]]

I thought the solution might involve numpy.unique, but I can't get it to work. Is there a native Python / NumPy function for this?
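For plain Python lists, a minimal sketch without NumPy is possible with collections.Counter. This assumes the rows are lists of hashable values, so they can be converted to tuples, and Python 3.7+, where Counter keeps insertion order. It produces both forms of output asked for above.

```python
from collections import Counter

rows = [[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]]

# Lists are unhashable, so each row is counted as a tuple.
counts = Counter(tuple(r) for r in rows)

# Each duplicated row once, in order of first appearance.
repeated = [list(r) for r, c in counts.items() if c > 1]

# Each duplicated row repeated as many times as it occurs.
repeated_with_counts = [list(r) for r, c in counts.items() if c > 1 for _ in range(c)]

print(repeated)              # [[1, 1, 1], [2, 2, 2]]
print(repeated_with_counts)  # [[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]]
```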

asked Jan 04 '18 at 16:01 by Ben Bird


People also ask

How do you delete duplicate rows in NumPy array?

numpy.unique() is a built-in NumPy function that takes an array and returns its sorted unique elements. By default it flattens the input, so for a 2-D array it returns unique scalar values rather than unique rows. To remove duplicate rows, pass axis=0, i.e. np.unique(a, axis=0), which returns the sorted unique rows.
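A small sketch showing the difference that the axis argument makes. It assumes NumPy 1.13 or newer, which is the first release where np.unique accepts axis.

```python
import numpy as np

a = np.array([[1, 1, 1], [2, 2, 2], [1, 1, 1]])

# Without axis, np.unique flattens the input and returns unique scalars.
print(np.unique(a))          # [1 2]

# With axis=0, whole rows are compared; the unique rows come back sorted.
print(np.unique(a, axis=0))  # [[1 1 1]
                             #  [2 2 2]]
```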

What is a possible way to find unique rows in a NumPy array?

To find unique rows in a NumPy array, use numpy.unique() with axis=0.
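np.unique returns the rows in sorted order. If the original order of first appearance matters, one possible approach, sketched here, is to request the first-occurrence indices with return_index=True and sort them.

```python
import numpy as np

a = np.array([[3, 3, 3], [1, 1, 1], [3, 3, 3], [2, 2, 2]])

# return_index gives the index of the first occurrence of each unique row.
_, first_idx = np.unique(a, axis=0, return_index=True)

# Sorting those indices restores the original order of first appearance.
unique_rows = a[np.sort(first_idx)]
print(unique_rows.tolist())  # [[3, 3, 3], [1, 1, 1], [2, 2, 2]]
```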

How do you repeat rows in NumPy?

To repeat elements of a NumPy array, use numpy.repeat(). It returns a new array in which each element is repeated a given number of times. The axis parameter selects the dimension to repeat along: for a 2-D array, axis=0 repeats rows and axis=1 repeats columns. Without axis, the input is flattened first.
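A short sketch of np.repeat on a 2-D array. The second argument can be a single count or one count per row.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

# axis=0 repeats whole rows; row i is repeated repeats[i] times.
print(np.repeat(a, [2, 1], axis=0).tolist())  # [[1, 2], [1, 2], [3, 4]]

# Without axis, the array is flattened before repeating.
print(np.repeat(a, 2).tolist())  # [1, 1, 2, 2, 3, 3, 4, 4]
```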


2 Answers

Use the axis argument of np.unique (available since NumPy 1.13) together with return_counts=True. This returns the unique rows and the number of times each one occurs. Selecting the unique rows whose count is greater than 1 then gives the desired output, like so:

In [688]: a = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]])

In [689]: unq, count = np.unique(a, axis=0, return_counts=True)

In [690]: unq[count>1]
Out[690]: 
array([[1, 1, 1],
       [2, 2, 2]])
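The question also accepts output that keeps each duplicated row as many times as it occurs. One way to get that, sketched here under the same NumPy 1.13+ assumption, is to pass the counts of the duplicated rows to np.repeat along axis=0.

```python
import numpy as np

a = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]])
unq, count = np.unique(a, axis=0, return_counts=True)

# Keep only the rows that occur more than once, together with their counts.
mask = count > 1

# Repeat each duplicated row as many times as it appears in `a`.
repeated = np.repeat(unq[mask], count[mask], axis=0)
print(repeated.tolist())  # [[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]]
```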
answered Oct 07 '22 at 03:10 by Divakar


If you need the indices of the repeated rows:

import numpy as np

a = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]])
unq, count = np.unique(a, axis=0, return_counts=True)
repeated_groups = unq[count > 1]

for repeated_group in repeated_groups:
    # Rows of `a` that match this duplicated row in every column
    repeated_idx = np.argwhere(np.all(a == repeated_group, axis=1))
    print(repeated_idx.ravel())

# [0 5]
# [1 6]
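A loop-free variant is also possible. The sketch below uses return_inverse, which maps every row of a to its position among the unique rows, and then collects the positions of each duplicated row. The .ravel() call is a precaution in case the installed NumPy version returns the inverse with an extra dimension.

```python
import numpy as np

a = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[1,1,1],[2,2,2]])

# inverse[i] is the index of row a[i] within the unique rows `unq`.
unq, inverse, count = np.unique(a, axis=0, return_inverse=True, return_counts=True)
inverse = inverse.ravel()  # keep it 1-D regardless of NumPy version

# For each unique row occurring more than once, find the indices of its copies.
groups = [np.flatnonzero(inverse == k) for k in np.flatnonzero(count > 1)]
for g in groups:
    print(g)
# [0 5]
# [1 6]
```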
answered Oct 07 '22 at 03:10 by v.grabovets