
Enforcing compatibility between numpy 1.8 and 1.9 nansum?

Tags: numpy

I have code that needs to behave identically independent of numpy version, but the underlying np.nansum function has changed behavior such that np.nansum([np.nan,np.nan]) is 0.0 in 1.9 and NaN in 1.8. The <=1.8 behavior is the one I would prefer, but the more important thing is that my code be robust against the numpy version.

The tricky thing is, the code applies an arbitrary numpy function (generally, a np.nan[something] function) to an ndarray. Is there any way to force the new or old numpy nan[something] functions to conform to the old or new behavior, short of monkeypatching them?

A possible solution I can think of is something like outarr[np.allnan(inarr, axis=axis)] = np.nan, but there is no np.allnan function - if this is the best solution, is the best implementation np.all(np.isnan(arr), axis=axis) (which would require only supporting np>=1.7, but that's probably OK)?
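The idea above can be sketched as a small wrapper. This is only an illustration, not a tested drop-in: the names all_nan_to_nan and nansum_compat are hypothetical, the input is assumed to be float-compatible, and only the axis argument is forwarded.

```python
import numpy as np

def all_nan_to_nan(func):
    """Wrap a reduction so that all-NaN slices yield NaN (hypothetical helper)."""
    def wrapped(arr, axis=None):
        # Assumption: the input can be represented as a float array.
        arr = np.asarray(arr, dtype=float)
        result = func(arr, axis=axis)
        # True wherever every element along `axis` is NaN.
        allnan = np.all(np.isnan(arr), axis=axis)
        # np.where covers both the scalar (axis=None) and the array case.
        return np.where(allnan, np.nan, result)
    return wrapped

# Example: a nansum that returns NaN for all-NaN slices on any numpy version.
nansum_compat = all_nan_to_nan(np.nansum)
```

Because the wrapper overwrites the result only where a slice is entirely NaN, it leaves the 1.8 output untouched and changes the 1.9 output only for those slices.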

keflavich asked Nov 10 '22 00:11



1 Answer

In Numpy 1.8, nansum was defined as:

a, mask = _replace_nan(a, 0)

if mask is None:
    return np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
mask = np.all(mask, axis=axis, keepdims=keepdims)
tot = np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
if np.any(mask):
    tot = _copyto(tot, np.nan, mask)
    warnings.warn("In Numpy 1.9 the sum along empty slices will be zero.",
                  FutureWarning)
return tot

In Numpy 1.9, it is:

a, mask = _replace_nan(a, 0)
return np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)

I don't think there is a way to make the new nansum behave the old way, but given that the original nansum code isn't that long, can you just include a copy of that code (without the warning) if you care about preserving the pre-1.9 behavior?

Note that _copyto can be imported from numpy.lib.nanfunctions.
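If you would rather not depend on private helpers such as _replace_nan and _copyto, the same logic can be approximated with public functions only. The following is a rough sketch rather than a copy of the numpy source: the name nansum_pre19 is made up for illustration, and the dtype and out arguments of the original are left out.

```python
import numpy as np

def nansum_pre19(a, axis=None, keepdims=False):
    """Sum treating NaN as zero, but return NaN for all-NaN slices.

    Hypothetical re-creation of the <=1.8 behavior using public API only.
    """
    a = np.asarray(a)
    if not np.issubdtype(a.dtype, np.inexact):
        # Integer and boolean arrays cannot contain NaN: plain sum.
        return np.sum(a, axis=axis, keepdims=keepdims)
    mask = np.isnan(a)
    # Replace NaN by zero before summing, as both numpy versions do.
    tot = np.sum(np.where(mask, 0, a), axis=axis, keepdims=keepdims)
    # Restore NaN where a slice consisted only of NaN (the 1.8 behavior).
    allnan = np.all(mask, axis=axis, keepdims=keepdims)
    return np.where(allnan, np.nan, tot)
```

The keepdims argument requires numpy 1.7 or later, which matches the minimum version mentioned in the question.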

astrofrog answered Jan 04 '23 01:01
