Replacing values greater than a limit in a numpy array

I have an n x m array and a maximum value for each column. What's the best way to replace values that exceed the maximum, other than checking each element one by one?

For example:

import numpy as np

def check_limits(bad_array, maxs):
    # Copy so the input array is left unchanged.
    good_array = np.copy(bad_array)
    for i_line in range(bad_array.shape[0]):
        for i_column in range(bad_array.shape[1]):
            # Any value at or above the column's limit becomes limit - 1.
            if good_array[i_line, i_column] >= maxs[i_column]:
                good_array[i_line, i_column] = maxs[i_column] - 1
    return good_array

Is there any way to do this faster and more concisely?

asked by ryanjg, Jan 20 '23 03:01


2 Answers

Use putmask:

import numpy as np

a = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])
m = np.array([7,6,5,4])

# This is what you need:

np.putmask(a, a >= m, m - 1)

# a is now:

np.array([[0, 1, 2, 3],
          [4, 5, 4, 3],
          [6, 5, 4, 3]])
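Why the single call lines up with the columns: when the values array is shorter than a, np.putmask repeats it, so a.flat[n] receives values[n % len(values)]. Since a.flat walks the array row by row and each row has exactly len(m) elements, flat index n lands in column n % len(m), which is exactly where m - 1 needs to go. A minimal sketch checking this against the question's loop; check_limits_loop is a hypothetical helper name, and np.where is a swapped-in alternative that relies on ordinary broadcasting instead of putmask's repetition:

```python
import numpy as np

def check_limits_loop(bad_array, maxs):
    # Hypothetical reference: the question's element-by-element loop (Python 3).
    good_array = np.copy(bad_array)
    for i in range(bad_array.shape[0]):
        for j in range(bad_array.shape[1]):
            if good_array[i, j] >= maxs[j]:
                good_array[i, j] = maxs[j] - 1
    return good_array

a = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])
m = np.array([7, 6, 5, 4])

# putmask modifies its first argument in place, so work on a copy of a.
via_putmask = a.copy()
np.putmask(via_putmask, via_putmask >= m, m - 1)

# Alternative: np.where broadcasts m (shape (4,)) against a (shape (3, 4))
# and returns a new array, leaving a untouched.
via_where = np.where(a >= m, m - 1, a)

print(via_putmask)
print(np.array_equal(via_putmask, check_limits_loop(a, m)))
print(np.array_equal(via_where, via_putmask))
```

The np.where form does not depend on the row length matching len(m) through repetition; the shapes simply have to be broadcast-compatible.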
answered by eumiro, Jan 29 '23 05:01


Another way is to use the clip function:

Using eumiro's example:

import numpy as np

bad_array = np.array([[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]])
maxs = np.array([7,6,5,4])

good_array = bad_array.clip(max=maxs-1)

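OR, to write the result into an existing array of the same shape:

good_array = np.empty_like(bad_array)
bad_array.clip(max=maxs-1, out=good_array)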
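For integer arrays, clip(max=maxs-1) gives the same result as the question's `>= maxs` rule, because an integer greater than maxs-1 is at least maxs. For floating-point data the two can differ: a value strictly between maxs-1 and maxs is kept by the loop but lowered by clip. A small sketch illustrating this; the float values below are illustrative, not from the question, and np.where stands in for the loop's rule:

```python
import numpy as np

maxs = np.array([7, 6, 5, 4])

# Integer data: clip(max=maxs - 1) matches "if x >= max: x = max - 1".
ints = np.array([[4, 5, 6, 7],
                 [8, 9, 10, 11]])
clipped = ints.clip(max=maxs - 1)
looped = np.where(ints >= maxs, maxs - 1, ints)
print(np.array_equal(clipped, looped))

# Float data: 6.5 is below the limit 7, so the ">=" rule keeps it,
# but clip(max=6) still lowers it to 6.0 (likewise for the other columns).
floats = np.array([[6.5, 5.5, 4.5, 3.5]])
print(floats.clip(max=maxs - 1))
print(np.where(floats >= maxs, maxs - 1, floats))
```

If the data can be fractional and the exact `>=` semantics matter, the putmask or np.where forms preserve them; otherwise clip is the most concise option.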

You can also specify a lower limit by adding the min= argument.
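Passing both bounds clamps each column into its own range in one call, since min and max broadcast against the array the same way maxs does. A minimal sketch, where mins is a hypothetical per-column lower limit that does not appear in the question:

```python
import numpy as np

bad_array = np.array([[0, 1, 2, 3],
                      [4, 5, 6, 7],
                      [8, 9, 10, 11]])
maxs = np.array([7, 6, 5, 4])
mins = np.array([1, 1, 3, 3])  # hypothetical lower limits, one per column

# Each column j ends up in the closed range [mins[j], maxs[j] - 1].
good_array = bad_array.clip(min=mins, max=maxs - 1)
print(good_array)
```

Here column 0 becomes [1, 4, 6] and column 3, whose range is [3, 3], becomes all 3s.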

answered by Ali Abul Hawa, Jan 29 '23 03:01