Getting the last non-nan index of a sorted numpy matrix or pandas dataframe

Given a numpy array (or pandas dataframe) like this:

import numpy as np

a = np.array([
[1,      1,      1,    0.5, np.nan, np.nan, np.nan],
[1,      1,      1, np.nan, np.nan, np.nan, np.nan],
[1,      1,      1,    0.5,   0.25,  0.125,  0.075],
[1,      1,      1,   0.25, np.nan, np.nan, np.nan],
[1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
[1,      1,    0.5,    0.5, np.nan, np.nan, np.nan]
])

I'm looking for the most efficient way to retrieve the index of the last non-NaN value in each row, so in this situation I'd want a function that returns something like this:

np.array([3,
          2,
          6,
          3,
          0,
          3])
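To pin down the expected output, here is a slow but unambiguous plain-Python reference. It is only a sketch: the function name `last_valid_index_loop` and the choice of `-1` for an all-NaN row are my own assumptions, not part of the question. Scanning each row from the right and breaking at the first non-NaN value is the "lazy" behaviour described below, though a Python loop will be much slower than any vectorized NumPy version.

```python
import math

import numpy as np

a = np.array([
    [1,      1,      1,    0.5, np.nan, np.nan, np.nan],
    [1,      1,      1, np.nan, np.nan, np.nan, np.nan],
    [1,      1,      1,    0.5,   0.25,  0.125,  0.075],
    [1,      1,      1,   0.25, np.nan, np.nan, np.nan],
    [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
    [1,      1,    0.5,    0.5, np.nan, np.nan, np.nan],
])


def last_valid_index_loop(arr):
    """Index of the last non-NaN entry in each row (hypothetical helper).

    Assumption: a row that is entirely NaN is reported as -1.
    """
    out = np.full(arr.shape[0], -1, dtype=np.intp)
    for i, row in enumerate(arr):
        # Walk from the right and stop at the first non-NaN value found.
        for j in range(row.shape[0] - 1, -1, -1):
            if not math.isnan(row[j]):
                out[i] = j
                break
    return out


print(last_valid_index_loop(a))  # [3 2 6 3 0 3]
```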

I can try np.argmin(a, axis=1) - 1, but this has at least two undesirable properties: it gives the wrong answer for rows that don't end with NaN (a dealbreaker), and it doesn't "lazy-evaluate" and stop once it reaches the last non-NaN value in a given row (this matters less than getting the right answer).
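The failure is easy to see on the example array: `np.argmin` treats NaN as the minimum and returns the first NaN it meets, so `argmin - 1` is right whenever a row contains a NaN, but for row 2 (no NaN) it returns the position of the smallest number instead. The sketch below also shows one possible fix that relies on an assumption taken from the title: if the NaNs only ever form a trailing run, the last valid index is just the number of non-NaN entries minus 1 (which also yields `-1` for an all-NaN row). It does not stop early either; it simply gets the right answer for sorted rows.

```python
import numpy as np

a = np.array([
    [1,      1,      1,    0.5, np.nan, np.nan, np.nan],
    [1,      1,      1, np.nan, np.nan, np.nan, np.nan],
    [1,      1,      1,    0.5,   0.25,  0.125,  0.075],
    [1,      1,      1,   0.25, np.nan, np.nan, np.nan],
    [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
    [1,      1,    0.5,    0.5, np.nan, np.nan, np.nan],
])

# The attempt from the question: argmin returns the first NaN when a row has one,
# but for row 2 (no NaN) it returns the index of 0.075 (6), and 6 - 1 = 5 is wrong.
attempt = np.argmin(a, axis=1) - 1
print(attempt)  # [3 2 5 3 0 3]

# Assumption: NaNs appear only as a trailing run in each row (the "sorted" case).
# Then the last valid index is the count of non-NaN entries minus 1.
counted = np.count_nonzero(~np.isnan(a), axis=1) - 1
print(counted)  # [3 2 6 3 0 3]
```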

I imagine there's a way to do it with np.where, but in addition to evaluating all the elements of each row, I can't see an obvious elegant way to rearrange the output to get the last index in each row:

>>> np.where(np.isnan(a))
(array([0, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5]),
 array([4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 1, 2, 3, 4, 5, 6, 4, 5, 6]))
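One way to rearrange that kind of output, sketched here as an illustration rather than a recommended final answer: run `np.nonzero` (what single-argument `np.where` does) on the *valid* mask instead of the NaN mask, then scatter the column indices into one slot per row with `np.maximum.at`, keeping the largest column seen for each row. The `-1` fill value for all-NaN rows is my own assumption. Like `np.where`, this still touches every element.

```python
import numpy as np

a = np.array([
    [1,      1,      1,    0.5, np.nan, np.nan, np.nan],
    [1,      1,      1, np.nan, np.nan, np.nan, np.nan],
    [1,      1,      1,    0.5,   0.25,  0.125,  0.075],
    [1,      1,      1,   0.25, np.nan, np.nan, np.nan],
    [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
    [1,      1,    0.5,    0.5, np.nan, np.nan, np.nan],
])

# Row and column of every non-NaN entry, in row-major order.
rows, cols = np.nonzero(~np.isnan(a))

# Unbuffered scatter-max: for each row keep the largest column index seen.
# Rows with no valid entry keep the -1 fill value (an assumed sentinel).
last = np.full(a.shape[0], -1, dtype=cols.dtype)
np.maximum.at(last, rows, cols)
print(last)  # [3 2 6 3 0 3]
```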
asked Dec 05 '22 16:12 by Paul


1 Answer

This solution doesn't require the array to be sorted. It returns the index of the last non-NaN item along axis 1.

(~np.isnan(a)).cumsum(1).argmax(1)
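Why this works: the cumulative count of valid entries stops increasing after the last valid one, and `argmax` returns the *first* position of the maximum, which is exactly that last valid position. The sketch below checks it on an unsorted row and also shows one edge case worth knowing about: an all-NaN row has a cumsum of all zeros, so `argmax` returns 0, indistinguishable from a row whose only valid value is at index 0. The wrapper `last_valid_index` and its `-1` sentinel are hypothetical additions of mine, not part of the answer above.

```python
import numpy as np


def last_valid_index(arr):
    """The cumsum/argmax one-liner, plus an assumed -1 for all-NaN rows."""
    valid = ~np.isnan(arr)
    # Running count of valid entries; its first maximum is the last valid column.
    idx = valid.cumsum(axis=1).argmax(axis=1)
    # All-NaN rows would otherwise report 0, so replace them with -1.
    return np.where(valid.any(axis=1), idx, -1)


b = np.array([
    [np.nan, 1.0, np.nan, 2.0, np.nan],         # unsorted: NaNs in the middle
    [5.0, np.nan, np.nan, np.nan, np.nan],      # only index 0 is valid
    [np.nan, np.nan, np.nan, np.nan, np.nan],   # entirely NaN
])

print((~np.isnan(b)).cumsum(1).argmax(1))  # [3 0 0]  (rows 1 and 2 look the same)
print(last_valid_index(b))                 # [ 3  0 -1]
```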
answered Dec 28 '22 22:12 by DougR