How to gather elements at specific indices in NumPy?

Tags:

numpy

I want to gather the elements at specified indices along a specified axis, like the following:

x = [[1, 2, 3], [4, 5, 6]]
index = [[2, 1], [0, 1]]
# desired result of gathering x along axis 1 with index:
# [[3, 2], [4, 5]]

This is essentially the gather operation in PyTorch (torch.gather), but as you know, writing x[:, index] does not do this in NumPy: it applies the whole index array to every row and returns an array of shape (2, 2, 2). Is there such a "gather" operation in NumPy?
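For comparison, the desired result can also be reached with plain advanced indexing. This is a minimal sketch, assuming x and index are first converted to NumPy arrays; the idea is to pair index with a column of row numbers from np.arange so that both index arrays broadcast to the shape of index.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
index = np.array([[2, 1], [0, 1]])

# Plain x[:, index] applies all of `index` to every row: shape (2, 2, 2).
print(x[:, index].shape)  # (2, 2, 2)

# rows has shape (2, 1) and broadcasts against index's shape (2, 2),
# so gathered[i, j] = x[i, index[i, j]].
rows = np.arange(x.shape[0])[:, np.newaxis]
gathered = x[rows, index]
print(gathered.tolist())  # [[3, 2], [4, 5]]
```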

asked Oct 21 '17 at 21:10 by ZEWEI CHU





2 Answers

numpy.take_along_axis (added in NumPy 1.15) does exactly this: it takes elements from an array according to an index array along a given axis, so it can be used like the gather method in PyTorch.

Here is an example from the NumPy documentation:

>>> a = np.array([[10, 30, 20], [60, 40, 50]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])
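Applied to the arrays from the question, a minimal sketch looks like this. np.take_along_axis expects the index array to have the same number of dimensions as the data (the other axes only need to broadcast), which the question's 2-D index already satisfies. The axis=0 case with index0 is an extra illustration, not part of the question.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
index = np.array([[2, 1], [0, 1]])

# axis=1: out[i, j] = x[i, index[i, j]]
out = np.take_along_axis(x, index, axis=1)
print(out.tolist())  # [[3, 2], [4, 5]]

# axis=0: out0[i, j] = x[index0[i, j], j]
index0 = np.array([[1, 0, 1]])
out0 = np.take_along_axis(x, index0, axis=0)
print(out0.tolist())  # [[4, 2, 6]]
```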

answered Sep 19 '22 at 00:09 by banma



I wrote this a while ago to replicate PyTorch's gather in NumPy. In this case, self is your x (the array to gather from) and index is an integer NumPy array of indices.

import numpy as np


def gather(self, dim, index):
    """
    Gathers values along an axis specified by ``dim``.

    For a 3-D array the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    Parameters
    ----------
    dim:
        The axis along which to index
    index:
        An integer array of indices of elements to gather

    Returns
    -------
    The gathered array, with the same shape as ``index``
    """
    # Every axis except `dim` must have the same size in index and self.
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    # Accept any integer dtype (int32, int64, ...), not only the platform default int_.
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    # Move `dim` to the front so np.choose treats each slice along it as one choice.
    data_swapped = np.swapaxes(self, 0, dim)
    index_swapped = np.swapaxes(index, 0, dim)
    gathered = np.choose(index_swapped, data_swapped)
    # Move the gathered axis back to its original position.
    return np.swapaxes(gathered, 0, dim)

These are the test cases:

import numpy as np

# Test 1
t = np.array([[65, 17], [14, 25], [76, 22]])
idx = np.array([[0], [1], [0]])
dim = 1
result = gather(t, dim=dim, index=idx)
expected = np.array([[65], [25], [76]])
print(np.array_equal(result, expected))  # True

# Test 2
t = np.array([[47, 74, 44], [56, 9, 37]])
idx = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0]])
dim = 0
result = gather(t, dim=dim, index=idx)
expected = np.array([[47, 74, 37], [56, 9, 44], [47, 9, 44]])
print(np.array_equal(result, expected))  # True
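One caveat with the np.choose approach: np.choose only supports a limited number of choices (on the order of 32 in NumPy 1.x, set by the NPY_MAXARGS constant), so the gather above fails once self.shape[dim] exceeds that. A minimal alternative sketch, assuming NumPy 1.15 or later, keeps the same shape and dtype checks but delegates the indexing to np.take_along_axis; gather_tal is a hypothetical name used only here.

```python
import numpy as np


def gather_tal(arr, dim, index):
    """Hypothetical PyTorch-style gather built on np.take_along_axis.

    Same contract as gather() above for non-negative dim, e.g. for dim == 1:
        out[i][j][k] = arr[i][index[i][j][k]][k]
    """
    index = np.asarray(index)
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    # Same rule as the np.choose version: every axis except `dim` must match.
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    arr_xsection_shape = arr.shape[:dim] + arr.shape[dim + 1:]
    if idx_xsection_shape != arr_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and arr should be the same size")
    # take_along_axis uses fancy indexing, so arr.shape[dim] is not limited.
    return np.take_along_axis(arr, index, axis=dim)


# The two test cases from above
t1 = np.array([[65, 17], [14, 25], [76, 22]])
r1 = gather_tal(t1, 1, np.array([[0], [1], [0]]))
print(r1.tolist())  # [[65], [25], [76]]

t2 = np.array([[47, 74, 44], [56, 9, 37]])
r2 = gather_tal(t2, 0, np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0]]))
print(r2.tolist())  # [[47, 74, 37], [56, 9, 44], [47, 9, 44]]

# 100 entries along the gathered axis, more than np.choose handles in NumPy 1.x
big = np.arange(200).reshape(2, 100)
r3 = gather_tal(big, 1, np.array([[99], [40]]))
print(r3.tolist())  # [[99], [140]]
```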
answered Sep 22 '22 at 00:09 by Sia Rezaei
