 

Numpy filter using condition on each element

Tags:

python

numpy

I have a filter expression as follows:

feasible_agents = filter(lambda agent: agent >= cost[task, agent], agents)

where agents is a Python list.
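For comparison, here is a runnable version of the pure-Python filter above. The 2-D `cost` matrix (rows are tasks, columns are agents) and the `task` value are hypothetical toy data, assumed from the `cost[task, agent]` indexing in the question. In Python 3, `filter` returns a lazy iterator, so it is wrapped in `list()` to get the result.

```python
import numpy as np

# Hypothetical toy data: cost[task, agent] is the cost of `agent` on `task`
cost = np.array([
    [4, 5, 6, 2],  # task 0
    [1, 0, 9, 7],  # task 1
])
task = 0
agents = [0, 1, 2, 3]  # plain Python list of agent IDs

# In Python 3, filter() returns a lazy iterator, so list() materialises it
feasible_agents = list(filter(lambda agent: agent >= cost[task, agent], agents))
print(feasible_agents)
```

This loops over the agents in Python one at a time, which is the part a NumPy version aims to avoid.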

Now, to get a speedup, I am trying to implement this using NumPy.

What would be the equivalent using NumPy?

I know that this works:

threshold = 5.0
feasible_agents = np_agents[np_agents > threshold]

where np_agents is the NumPy equivalent of agents.

However, I want the threshold to be different for each element: a function of that element, like cost[task, agent] in the filter above.
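The comparison in `np_agents > threshold` is not limited to a scalar: if the right-hand side is an array of the same shape, NumPy compares element by element and still produces a boolean mask. So a per-element threshold only requires building that array first. In this sketch the threshold function `np.sqrt(np_agents) + 1.0` and the agent values are hypothetical, chosen just to show the mechanics.

```python
import numpy as np

np_agents = np.array([1.0, 3.0, 6.0, 10.0])  # hypothetical agent values

# Hypothetical per-element threshold: any vectorised expression of the array works
thresholds = np.sqrt(np_agents) + 1.0

# Same-shaped arrays are compared element by element, giving a boolean mask
mask = np_agents > thresholds
feasible_agents = np_agents[mask]
print(feasible_agents)
```

If the threshold can only be written as arbitrary per-element Python code, `np.vectorize` can wrap it, but it still loops in Python internally, so it offers convenience rather than a speedup.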

asked Sep 18 '18 07:09 by Niloy Saha


2 Answers

You can use numpy.extract:

>>> import numpy as np
>>> nparr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
>>> nparreven = np.extract(nparr % 2 == 0, nparr)

or numpy.where:

>>> nparr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
>>> nparreven = nparr[np.where(nparr % 2 == 0)]
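Both forms select the same elements as plain boolean indexing. `np.extract(condition, arr)` is documented as equivalent to `np.compress(np.ravel(condition), np.ravel(arr))`, and `np.where` called with only a condition is equivalent to `np.nonzero(condition)`, so it returns a tuple of index arrays rather than the values themselves. A small sketch comparing the three on the even-number example:

```python
import numpy as np

nparr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
condition = nparr % 2 == 0  # boolean mask, one entry per element

# 1) np.extract returns the elements where condition is True (flattened)
via_extract = np.extract(condition, nparr)

# 2) With only a condition, np.where returns a tuple of index arrays, one per dimension
indices = np.where(condition)
via_where = nparr[indices]

# 3) Plain boolean-mask indexing, as in the question
via_mask = nparr[condition]

print(indices[0])  # positions of the even numbers
print(via_extract, via_where, via_mask)
```

The mask is what carries the per-element logic; any of the three forms then works with a condition built from the question's `cost[task, agent]` lookup. The `np.where` route is mainly useful when the positions themselves are also needed.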
answered Sep 23 '22 03:09 by unlut


Since you don't provide example data, here is some toy data:

import numpy as np

# Cost of each agent, indexed by agent ID: we have agents 0, 1, 2, 3
cost = np.array([4, 5, 6, 2])
# Agents to consider
np_agents = np.array([0, 1, 3])
# Threshold for each agent: index the cost array with the array of agent IDs
thresholds = cost[np_agents]  # np.array([4, 5, 2])
# Keep agents whose value is at least their own threshold (>= as in the question's filter)
feasible_agents = np_agents[np_agents >= thresholds]  # np.array([3])
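The toy data above uses a 1-D `cost`, while the question indexes a 2-D `cost[task, agent]`. Assuming rows are tasks and columns are agents (an assumption; the question does not say), the same idea carries over by fixing the row and indexing the columns with the agent array. The matrix values and `task` below are hypothetical.

```python
import numpy as np

# Hypothetical 2-D cost matrix: rows are tasks, columns are agents 0..3
cost = np.array([
    [4, 5, 6, 2],
    [1, 0, 9, 7],
])
task = 1
np_agents = np.array([0, 1, 2, 3])

# cost[task, np_agents] picks one cost per agent from row `task`
thresholds = cost[task, np_agents]

# Vectorised form of: filter(lambda agent: agent >= cost[task, agent], agents)
feasible_agents = np_agents[np_agents >= thresholds]
print(feasible_agents)
```

Mixing an integer (`task`) with an integer array (`np_agents`) in one index is NumPy advanced indexing, so the lookup and the comparison both run without a Python-level loop.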
answered Sep 23 '22 03:09 by Deepak Saini