
Determine sum of numpy array while excluding certain values

I would like to compute the sum of a two-dimensional numpy array while excluding elements that have a certain value. What is the most efficient way to do this?

For example, here I initialize a two dimensional numpy array of 1s and replace several of them by 2:

import numpy

data_set = numpy.ones((10, 10))

data_set[4][4] = 2
data_set[5][5] = 2
data_set[6][6] = 2

How can I sum over the elements of my two-dimensional array while excluding all of the 2s? Note that for the 10 by 10 array the correct answer is 97, since I replaced three elements with the value 2.

I know I can do this with nested for loops. For example:

elements = []
for idx_x in range(data_set.shape[0]):
  for idx_y in range(data_set.shape[1]):
    if data_set[idx_x][idx_y] != 2:
      elements.append(data_set[idx_x][idx_y])

data_set_sum = numpy.sum(elements)

However, on my actual data (which is very large) this is too slow. What is the correct way of doing this?

The Dude asked Dec 02 '22 19:12


2 Answers

Use numpy's ability to index with boolean arrays. In the example below, data_set != 2 evaluates to a boolean array with the same shape as data_set that is True wherever the element is not 2. So data_set[data_set != 2] is a fast and convenient way to get a one-dimensional array of the elements that do not have a certain value. Of course, the boolean expression can be more complex.

In [1]: import numpy as np
In [2]: data_set = np.ones((10, 10))
In [4]: data_set[4,4] = 2
In [5]: data_set[5,5] = 2
In [6]: data_set[6,6] = 2
In [7]: data_set[data_set != 2].sum()
Out[7]: 97.0
In [8]: data_set != 2
Out[8]: 
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       ...
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]], dtype=bool)
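
As a rough sketch of how the boolean expression can be made more complex: the `excluded` list and the variable names below are illustrative assumptions, not part of the original question, and the `where=` argument of `np.sum` assumes NumPy 1.17 or later.

```python
import numpy as np

data_set = np.ones((10, 10))
data_set[4, 4] = 2
data_set[5, 5] = 2
data_set[6, 6] = 2

# Combine conditions with & or |; each comparison needs its own parentheses.
mask = (data_set != 2) & (data_set > 0)
total_combined = data_set[mask].sum()

# Exclude several values at once (hypothetical list of values to skip).
excluded = [2, 3]
total_isin = data_set[~np.isin(data_set, excluded)].sum()

# Sum in place without building the intermediate filtered array
# (assumes NumPy >= 1.17, where np.sum accepts `where=`).
total_where = np.sum(data_set, where=data_set != 2)

print(total_combined, total_isin, total_where)
```

All three variants should give the same result as data_set[data_set != 2].sum() on this array. The `where=` form may be worth trying on very large arrays because it avoids the temporary copy.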
Korem answered Apr 09 '23 06:04


Without numpy, the solution is not much more complex:

x = [1,2,3,4,5,6,7]
sum(y for y in x if y != 7)
# 21

Works for a list of excluded values too:

# membership tests with `in` are faster on a set than on a list
exl = {1, 2, 3}
sum(y for y in x if y not in exl)
# 22
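
The same idea can be stretched to a two-dimensional structure like the one in the question by flattening it first. This is only a sketch: `rows` is a made-up nested list, and for large data the numpy approach is likely to be much faster.

```python
from itertools import chain

# Hypothetical nested list standing in for a small 2D data set.
rows = [[1, 1, 2],
        [1, 2, 1],
        [2, 1, 1]]

excluded = {2}

# chain.from_iterable walks the rows one after another without copying them.
total = sum(v for v in chain.from_iterable(rows) if v not in excluded)
print(total)
```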
njzk2 answered Apr 09 '23 07:04