Delete columns of a matrix that are mostly zero

Tags: python, numpy

Here, "mostly zero" means a column has fewer than 5 non-zero elements. The matrix is a 2D ndarray.

Sample data:

a = np.array([[1,1,2,1,1],
              [1,1,0,1,0],
              [1,1,0,1,0],
              [1,1,0,3,0],
              [1,1,0,3,0],
              [1,1,1,5,3],
              [1,1,0,1,0],
              [1,1,0,1,0],
              [1,1,4,3,0],
              [1,1,0,4,0],
              [1,1,0,5,0],
              [1,1,0,0,0]])

Output:

a = np.array([[1,1,1],
              [1,1,1],
              [1,1,1],
              [1,1,3],
              [1,1,3],
              [1,1,5],
              [1,1,1],
              [1,1,1],
              [1,1,3],
              [1,1,4],
              [1,1,5],
              [1,1,0]])
siamii asked Feb 17 '23 16:02


1 Answer

How about:

>>> a[:, (a != 0).sum(axis=0) >= 5]
array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 3],
       [1, 1, 3],
       [1, 1, 5],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 3],
       [1, 1, 4],
       [1, 1, 5],
       [1, 1, 0]])

or

>>> a[:, np.apply_along_axis(np.count_nonzero, 0, a) >= 5]
array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 3],
       [1, 1, 3],
       [1, 1, 5],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 3],
       [1, 1, 4],
       [1, 1, 5],
       [1, 1, 0]])
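
To see why both expressions select the same columns, it may help to build the mask one step at a time. This is only an illustrative breakdown of the first one-liner using the sample data from the question; the intermediate names (`nonzero`, `counts`, `keep`) are made up here and are not part of the original answer.

```python
import numpy as np

a = np.array([[1, 1, 2, 1, 1],
              [1, 1, 0, 1, 0],
              [1, 1, 0, 1, 0],
              [1, 1, 0, 3, 0],
              [1, 1, 0, 3, 0],
              [1, 1, 1, 5, 3],
              [1, 1, 0, 1, 0],
              [1, 1, 0, 1, 0],
              [1, 1, 4, 3, 0],
              [1, 1, 0, 4, 0],
              [1, 1, 0, 5, 0],
              [1, 1, 0, 0, 0]])

nonzero = a != 0               # boolean array, same shape as a
counts = nonzero.sum(axis=0)   # True counts as 1, so this is the non-zero count per column
keep = counts >= 5             # boolean mask over columns: True means "keep"
result = a[:, keep]            # boolean indexing along the column axis

print(counts)  # [12 12  3 11  2]
print(keep)    # [ True  True False  True False]
```

Columns 2 and 4 have only 3 and 2 non-zero entries, so they are the ones dropped.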

In the past I've found np.count_nonzero to be much faster than the sum trick, but here, probably because it has to go through np.apply_along_axis, that version is much slower, at least for this a. Some other tests showed the same even for larger matrices, but YMMV.
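
As a side note, newer NumPy releases (1.12 and later) let np.count_nonzero take an axis argument directly, which avoids np.apply_along_axis altogether. The sketch below assumes such a version (and NumPy 1.17+ for np.random.default_rng); the helper name `drop_sparse_columns`, its `min_nonzero` parameter, and the random test matrix `big` are hypothetical and only for illustration. It checks that the result matches the sum trick and prints rough timings, which will vary by machine and array shape.

```python
import timeit

import numpy as np


def drop_sparse_columns(a, min_nonzero=5):
    """Keep only the columns of `a` with at least `min_nonzero` non-zero entries."""
    counts = np.count_nonzero(a, axis=0)  # axis= needs NumPy >= 1.12
    return a[:, counts >= min_nonzero]


# Hypothetical test matrix: random values in {0, 1, 2}, every 4th column zeroed out.
rng = np.random.default_rng(0)
big = rng.integers(0, 3, size=(1000, 200))
big[:, ::4] = 0

via_sum = big[:, (big != 0).sum(axis=0) >= 5]
via_count = drop_sparse_columns(big)
assert np.array_equal(via_sum, via_count)  # both approaches agree

# Rough timing comparison; no particular outcome is assumed.
t_sum = timeit.timeit(lambda: big[:, (big != 0).sum(axis=0) >= 5], number=100)
t_count = timeit.timeit(lambda: drop_sparse_columns(big), number=100)
print(f"sum trick: {t_sum:.4f}s  count_nonzero(axis=0): {t_count:.4f}s")
```

Whether this beats the sum trick is worth measuring on your own data rather than assuming.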

DSM answered Feb 19 '23 11:02