summaryrefslogtreecommitdiff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
authorClaudio Freire <klaussfreire@gmail.com>2018-01-31 00:50:57 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-01-31 01:01:40 -0800
commit3c05221dd443323e7948eadf8105c0b52e760f70 (patch)
treee5ae9c4255126b8c854ee281abe262d29fbe50af /numpy/lib/arraysetops.py
parentbb7b12672fe7b68c7776a7f757741d4632001bf3 (diff)
downloadnumpy-3c05221dd443323e7948eadf8105c0b52e760f70.tar.gz
MAINT: Remove messy handling of output tuple in np.unique
Largely taken from gh-9531
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r--numpy/lib/arraysetops.py64
1 files changed, 33 insertions, 31 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index b1e74dc74..78d4536c0 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -110,6 +110,14 @@ def ediff1d(ary, to_end=None, to_begin=None):
return result
+def _unpack_tuple(x):
+ """ Unpacks one-element tuples for use as return values """
+ if len(x) == 1:
+ return x[0]
+ else:
+ return x
+
+
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis=None):
"""
@@ -211,7 +219,9 @@ def unique(ar, return_index=False, return_inverse=False,
"""
ar = np.asanyarray(ar)
if axis is None:
- return _unique1d(ar, return_index, return_inverse, return_counts)
+ ret = _unique1d(ar, return_index, return_inverse, return_counts)
+ return _unpack_tuple(ret)
+
if not (-ar.ndim <= axis < ar.ndim):
raise ValueError('Invalid axis kwarg specified for unique')
@@ -245,11 +255,9 @@ def unique(ar, return_index=False, return_inverse=False,
output = _unique1d(consolidated, return_index,
return_inverse, return_counts)
- if not (return_index or return_inverse or return_counts):
- return reshape_uniq(output)
- else:
- uniq = reshape_uniq(output[0])
- return (uniq,) + output[1:]
+ output = (reshape_uniq(output[0]),) + output[1:]
+ return _unpack_tuple(output)
+
def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False):
@@ -259,19 +267,15 @@ def _unique1d(ar, return_index=False, return_inverse=False,
ar = np.asanyarray(ar).flatten()
optional_indices = return_index or return_inverse
- optional_returns = optional_indices or return_counts
if ar.size == 0:
- if not optional_returns:
- ret = ar
- else:
- ret = (ar,)
- if return_index:
- ret += (np.empty(0, np.intp),)
- if return_inverse:
- ret += (np.empty(0, np.intp),)
- if return_counts:
- ret += (np.empty(0, np.intp),)
+ ret = (ar,)
+ if return_index:
+ ret += (np.empty(0, np.intp),)
+ if return_inverse:
+ ret += (np.empty(0, np.intp),)
+ if return_counts:
+ ret += (np.empty(0, np.intp),)
return ret
if optional_indices:
@@ -282,22 +286,20 @@ def _unique1d(ar, return_index=False, return_inverse=False,
aux = ar
flag = np.concatenate(([True], aux[1:] != aux[:-1]))
- if not optional_returns:
- ret = aux[flag]
- else:
- ret = (aux[flag],)
- if return_index:
- ret += (perm[flag],)
- if return_inverse:
- iflag = np.cumsum(flag) - 1
- inv_idx = np.empty(ar.shape, dtype=np.intp)
- inv_idx[perm] = iflag
- ret += (inv_idx,)
- if return_counts:
- idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
- ret += (np.diff(idx),)
+ ret = (aux[flag],)
+ if return_index:
+ ret += (perm[flag],)
+ if return_inverse:
+ iflag = np.cumsum(flag) - 1
+ inv_idx = np.empty(ar.shape, dtype=np.intp)
+ inv_idx[perm] = iflag
+ ret += (inv_idx,)
+ if return_counts:
+ idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
+ ret += (np.diff(idx),)
return ret
+
def intersect1d(ar1, ar2, assume_unique=False):
"""
Find the intersection of two arrays.