
How does numpy's argpartition work on the documentation's example?

Tags:

python

numpy

I am trying to understand numpy's argpartition function. I have made the documentation's example as basic as possible.

import numpy as np

x = np.array([3, 4, 2, 1])
print("x: ", x)

a=np.argpartition(x, 3)
print("a: ", a)

print("x[a]:", x[a])

This is the output...

('x: ', array([3, 4, 2, 1]))
('a: ', array([2, 3, 0, 1]))
('x[a]:', array([2, 1, 3, 4]))

In the line a=np.argpartition(x, 3) isn't the kth element the last element (the number 1)? If it is number 1, when x is sorted shouldn't 1 become the first element (element 0)?

In x[a], why is 2 the first element "in front" of 1?

What fundamental thing am I missing?

Mel asked Sep 23 '18 10:09


3 Answers

The more complete answer to what argpartition does is in the documentation of partition, and that one says:

Creates a copy of the array with its elements rearranged in such a way that the value of the element in k-th position is in the position it would be in a sorted array. All elements smaller than the k-th element are moved before this element and all equal or greater are moved behind it. The ordering of the elements in the two partitions is undefined.

So, for the input array 3, 4, 2, 1, the sorted array would be 1, 2, 3, 4.

The result of np.partition([3, 4, 2, 1], 3) will have the correct value (i.e. the same as in the sorted array) at index 3 (i.e. the last position). The correct value at index 3 is 4.

Let me show this for all values of k to make it clear (the order within each partition is undefined, so your exact output may differ):

  • np.partition([3, 4, 2, 1], 0) - [1, 4, 2, 3]
  • np.partition([3, 4, 2, 1], 1) - [1, 2, 4, 3]
  • np.partition([3, 4, 2, 1], 2) - [1, 2, 3, 4]
  • np.partition([3, 4, 2, 1], 3) - [2, 1, 3, 4]

In other words: the k-th element of the result is the same as the k-th element of the sorted array. All elements before k are smaller than or equal to that element. All elements after it are greater than or equal to it.
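The three guarantees above can be checked directly. This is a minimal sketch; check_partition is a hypothetical helper name introduced here for illustration, not part of NumPy. It deliberately tests only the documented guarantees and not the exact order of the output.

```python
import numpy as np

x = np.array([3, 4, 2, 1])

def check_partition(arr, k):
    """Return True if np.partition(arr, k) satisfies the documented guarantees."""
    p = np.partition(arr, k)
    pivot = p[k]
    return bool(
        pivot == np.sort(arr)[k]        # index k holds the sorted value
        and np.all(p[:k] <= pivot)      # everything before is <= pivot
        and np.all(p[k + 1:] >= pivot)  # everything after is >= pivot
    )

# Check every valid k for the four-element example.
results = {k: check_partition(x, k) for k in range(len(x))}
print(results)  # {0: True, 1: True, 2: True, 3: True}
```

The check passes for every k even if a different NumPy version orders the two sides differently from the list above.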

The same happens with argpartition, except argpartition returns indices which can then be used to form the same result.

zvone answered Nov 18 '22 21:11


Similar to @Imtinan, I struggled with this. I found it useful to break up the function into the arg and the partition.

Take the following array:

array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])

The corresponding indices are [0,1,2,3,4,5,6,7,8], where index 8 holds the value 5 and index 0 holds the value 9.

If we do np.partition(array, kth=5), NumPy finds the value that would sit at index 5 if the array were sorted (here 6) and places it at index 5 of a new array. It then puts the elements smaller than that value before it and the elements larger than it after it, like this:

pseudo output: [lower value elements, element at index 5, higher value elements]

if we compute this we get:

array([3, 5, 1, 4, 2, 6, 8, 7, 9])

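This makes sense: the value at index 5 of the sorted array is 6, [1,2,3,4,5] are all lower than 6 and [7,8,9] are higher than 6. Note that the elements within each group are not ordered. (6 also happens to be the fifth item of the original array, but that is a coincidence; kth refers to the position in sorted order.)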

The arg part of the np.argpartition() then goes one step further and swaps the elements out for their respective indices in the original array. So if we did:

np.argpartition(array, 5) we will get:

array([5, 8, 7, 3, 1, 4, 6, 2, 0])

From above, the original array had this structure [index=value]: [0=9, 1=2, 2=7, 3=4, 4=6, 5=3, 6=8, 7=1, 8=5]

You can map each index in the output back to its value, and you will satisfy the condition:

array[np.argpartition(array, 5)] = np.partition(array, 5), like this:

[index form] array([5, 8, 7, 3, 1, 4, 6, 2, 0]) becomes

[3, 5, 1, 4, 2, 6, 8, 7, 9]

which is the same as the output of np.partition(array, 5),

array([3, 5, 1, 4, 2, 6, 8, 7, 9])

Hopefully this makes sense; it was the only way I could get my head around the arg part of the function.
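The index-to-value mapping described above can be reproduced with fancy indexing. This is a small sketch using the same nine-element array. It compares only what the documentation guarantees (the value at position k and which values land on each side), since the order inside each side is undefined.

```python
import numpy as np

array = np.array([9, 2, 7, 4, 6, 3, 8, 1, 5])
k = 5

idx = np.argpartition(array, k)   # indices into the original array
via_idx = array[idx]              # values gathered through those indices
direct = np.partition(array, k)   # values partitioned directly

# Position k holds the value a full sort would put there.
print(via_idx[k], direct[k], np.sort(array)[k])  # 6 6 6

# Each side holds the same set of values, whatever the internal order.
print(sorted(via_idx[:k].tolist()), sorted(direct[:k].tolist()))
print(sorted(via_idx[k + 1:].tolist()), sorted(direct[k + 1:].tolist()))
```

The last two lines print [1, 2, 3, 4, 5] twice and [7, 8, 9] twice.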

DSeal6 answered Nov 18 '22 21:11


I remember having a hard time figuring it out too. Maybe the documentation is written badly, but this is what it means:

When you do a=np.argpartition(x, 3), you get indices that rearrange x so that only the element at index k is guaranteed to be in its sorted position (in our case k=3).

So when you run this code, you are basically asking what the value at index 3 would be in a sorted array. Hence the output is ('x[a]:', array([2, 1, 3, 4])), where only the element at index 3 is in its sorted position.

As the documentation says, all numbers smaller than the kth element come before it (in no particular order), hence you get 2 before 1.
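To make the "no particular order" point concrete, here is a short sketch on the question's array. It relies only on the guarantee at index 3. The slice before it is sorted explicitly because its order is not promised.

```python
import numpy as np

x = np.array([3, 4, 2, 1])
a = np.argpartition(x, 3)

# Only index 3 is guaranteed: it holds what a full sort would put there.
print(x[a][3])  # 4

# The values before index 3 are the smaller ones, in no promised order,
# so sort the slice if an ordered view is needed.
smaller = x[a][:3]
print(sorted(smaller.tolist()))  # [1, 2, 3]
```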

I hope this clarifies it; if you are still confused, feel free to comment :)

Imtinan Azhar answered Nov 18 '22 21:11