diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-05-02 11:22:37 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-05-02 12:58:05 -0600 |
commit | 63ff998b5789a5aa7ce7830a258445544f1063ae (patch) | |
tree | 4aaab6abf03297f9f5860da46d25fc7081a757ba | |
parent | 602836b99c2a046574dd352def24980b2da9d9b2 (diff) | |
download | numpy-63ff998b5789a5aa7ce7830a258445544f1063ae.tar.gz |
MAINT: Refactor numpy.ma.where.
Use np._NoValue as default values of 'x' and 'y' in signature. That
allows None values to be used for compatibility with the unmasked
version of where and makes detection of the case when neither is
given easier.
Use properties *.data and *.mask instead of *._data and *._mask. That
may avoid some subtle problems.
Minor style cleanups.
-rw-r--r-- | numpy/ma/core.py | 64 |
1 files changed, 24 insertions, 40 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index b52dc0ab7..f0a97bd8c 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -29,7 +29,7 @@ from functools import reduce import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes -from numpy import ndarray, amax, amin, iscomplexobj, bool_ +from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue from numpy import array as narray from numpy.lib.function_base import angle from numpy.compat import getargspec, formatargspec, long, basestring @@ -6708,7 +6708,7 @@ size.__doc__ = np.size.__doc__ #####-------------------------------------------------------------------------- #---- --- Extra functions --- #####-------------------------------------------------------------------------- -def where(condition, *args, **kwargs): +def where(condition, x=_NoValue, y=_NoValue): """ Return a masked array with elements from x or y, depending on condition. @@ -6755,40 +6755,20 @@ def where(condition, *args, **kwargs): [6.0 -- 8.0]] """ + missing = (x is _NoValue, y is _NoValue).count(True) - if ((len(args) == 0) or (len(args) == 2)) and (len(kwargs) == 0): - pass - elif (len(args) == 1) and (len(kwargs) == 1): - if ("x" not in kwargs): - raise ValueError( - "Cannot provide `x` as an argument and a keyword argument." - ) - if ("y" not in kwargs): - raise ValueError( - "Must provide `y` as a keyword argument if not as an argument." - ) - args += (kwargs["y"],) - elif (len(args) == 0) and ((len(kwargs) == 0) or (len(kwargs) == 2)): - if (("x" not in kwargs) and ("y" not in kwargs) or - ("x" in kwargs) and ("y" in kwargs)): - raise ValueError( - "Must provide both `x` and `y` as keyword arguments or neither." - ) - args += (kwargs["x"], kwargs["y"],) - else: - raise ValueError( - "Only takes 3 arguments was provided %i arguments." % - len(args) + len(kwargs) - ) + if missing == 1: + raise ValueError("Must provide both 'x' and 'y' or neither.") + if missing == 2: + return filled(condition, 0).nonzero() + # Both x and y are provided - if len(args) == 0: - return filled(condition, 0).nonzero() - x, y = args - # Get the condition ............... + # Get the condition fc = filled(condition, 0).astype(MaskType) notfc = np.logical_not(fc) - # Get the data ...................................... + + # Get the data xv = getdata(x) yv = getdata(y) if x is masked: @@ -6797,20 +6777,24 @@ def where(condition, *args, **kwargs): ndtype = xv.dtype else: ndtype = np.find_common_type([xv.dtype, yv.dtype], []) + # Construct an empty array and fill it d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray) - _data = d._data - np.copyto(_data, xv.astype(ndtype), where=fc) - np.copyto(_data, yv.astype(ndtype), where=notfc) + np.copyto(d._data, xv.astype(ndtype), where=fc) + np.copyto(d._data, yv.astype(ndtype), where=notfc) + # Create an empty mask and fill it - _mask = d._mask = np.zeros(fc.shape, dtype=MaskType) - np.copyto(_mask, getmask(x), where=fc) - np.copyto(_mask, getmask(y), where=notfc) - _mask |= getmaskarray(condition) - if not _mask.any(): - d._mask = nomask + mask = np.zeros(fc.shape, dtype=MaskType) + np.copyto(mask, getmask(x), where=fc) + np.copyto(mask, getmask(y), where=notfc) + mask |= getmaskarray(condition) + + # Use d._mask instead of d.mask to avoid copies + d._mask = mask if mask.any() else nomask + return d + def choose (indices, choices, out=None, mode='raise'): """ Use an index array to construct a new array from a set of choices. |