summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-05-02 11:22:37 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-05-02 12:58:05 -0600
commit63ff998b5789a5aa7ce7830a258445544f1063ae (patch)
tree4aaab6abf03297f9f5860da46d25fc7081a757ba
parent602836b99c2a046574dd352def24980b2da9d9b2 (diff)
downloadnumpy-63ff998b5789a5aa7ce7830a258445544f1063ae.tar.gz
MAINT: Refactor numpy.ma.where.
Use np._NoValue as default values of 'x' and 'y' in signature. That allows None values to be used for compatibility with the unmasked version of where and makes detection of the case when neither is given easier. Use properties *.data and *.mask instead of *._data and *._mask. That may avoid some subtle problems. Minor style cleanups.
-rw-r--r--numpy/ma/core.py64
1 files changed, 24 insertions, 40 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index b52dc0ab7..f0a97bd8c 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -29,7 +29,7 @@ from functools import reduce
import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
-from numpy import ndarray, amax, amin, iscomplexobj, bool_
+from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue
from numpy import array as narray
from numpy.lib.function_base import angle
from numpy.compat import getargspec, formatargspec, long, basestring
@@ -6708,7 +6708,7 @@ size.__doc__ = np.size.__doc__
#####--------------------------------------------------------------------------
#---- --- Extra functions ---
#####--------------------------------------------------------------------------
-def where(condition, *args, **kwargs):
+def where(condition, x=_NoValue, y=_NoValue):
"""
Return a masked array with elements from x or y, depending on condition.
@@ -6755,40 +6755,20 @@ def where(condition, *args, **kwargs):
[6.0 -- 8.0]]
"""
+ missing = (x is _NoValue, y is _NoValue).count(True)
- if ((len(args) == 0) or (len(args) == 2)) and (len(kwargs) == 0):
- pass
- elif (len(args) == 1) and (len(kwargs) == 1):
- if ("x" not in kwargs):
- raise ValueError(
- "Cannot provide `x` as an argument and a keyword argument."
- )
- if ("y" not in kwargs):
- raise ValueError(
- "Must provide `y` as a keyword argument if not as an argument."
- )
- args += (kwargs["y"],)
- elif (len(args) == 0) and ((len(kwargs) == 0) or (len(kwargs) == 2)):
- if (("x" not in kwargs) and ("y" not in kwargs) or
- ("x" in kwargs) and ("y" in kwargs)):
- raise ValueError(
- "Must provide both `x` and `y` as keyword arguments or neither."
- )
- args += (kwargs["x"], kwargs["y"],)
- else:
- raise ValueError(
- "Only takes 3 arguments was provided %i arguments." %
- len(args) + len(kwargs)
- )
+ if missing == 1:
+ raise ValueError("Must provide both 'x' and 'y' or neither.")
+ if missing == 2:
+ return filled(condition, 0).nonzero()
+ # Both x and y are provided
- if len(args) == 0:
- return filled(condition, 0).nonzero()
- x, y = args
- # Get the condition ...............
+ # Get the condition
fc = filled(condition, 0).astype(MaskType)
notfc = np.logical_not(fc)
- # Get the data ......................................
+
+ # Get the data
xv = getdata(x)
yv = getdata(y)
if x is masked:
@@ -6797,20 +6777,24 @@ def where(condition, *args, **kwargs):
ndtype = xv.dtype
else:
ndtype = np.find_common_type([xv.dtype, yv.dtype], [])
+
# Construct an empty array and fill it
d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray)
- _data = d._data
- np.copyto(_data, xv.astype(ndtype), where=fc)
- np.copyto(_data, yv.astype(ndtype), where=notfc)
+ np.copyto(d._data, xv.astype(ndtype), where=fc)
+ np.copyto(d._data, yv.astype(ndtype), where=notfc)
+
# Create an empty mask and fill it
- _mask = d._mask = np.zeros(fc.shape, dtype=MaskType)
- np.copyto(_mask, getmask(x), where=fc)
- np.copyto(_mask, getmask(y), where=notfc)
- _mask |= getmaskarray(condition)
- if not _mask.any():
- d._mask = nomask
+ mask = np.zeros(fc.shape, dtype=MaskType)
+ np.copyto(mask, getmask(x), where=fc)
+ np.copyto(mask, getmask(y), where=notfc)
+ mask |= getmaskarray(condition)
+
+ # Use d._mask instead of d.mask to avoid copies
+ d._mask = mask if mask.any() else nomask
+
return d
+
def choose (indices, choices, out=None, mode='raise'):
"""
Use an index array to construct a new array from a set of choices.