diff options
author | Claudio Freire <klaussfreire@gmail.com> | 2018-01-31 00:50:57 -0800 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-01-31 01:01:40 -0800 |
commit | 3c05221dd443323e7948eadf8105c0b52e760f70 (patch) | |
tree | e5ae9c4255126b8c854ee281abe262d29fbe50af /numpy/lib/arraysetops.py | |
parent | bb7b12672fe7b68c7776a7f757741d4632001bf3 (diff) | |
download | numpy-3c05221dd443323e7948eadf8105c0b52e760f70.tar.gz |
MAINT: Remove messy handling of output tuple in np.unique
Largely taken from gh-9531
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 64 |
1 files changed, 33 insertions, 31 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index b1e74dc74..78d4536c0 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -110,6 +110,14 @@ def ediff1d(ary, to_end=None, to_begin=None): return result +def _unpack_tuple(x): + """ Unpacks one-element tuples for use as return values """ + if len(x) == 1: + return x[0] + else: + return x + + def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ @@ -211,7 +219,9 @@ def unique(ar, return_index=False, return_inverse=False, """ ar = np.asanyarray(ar) if axis is None: - return _unique1d(ar, return_index, return_inverse, return_counts) + ret = _unique1d(ar, return_index, return_inverse, return_counts) + return _unpack_tuple(ret) + if not (-ar.ndim <= axis < ar.ndim): raise ValueError('Invalid axis kwarg specified for unique') @@ -245,11 +255,9 @@ def unique(ar, return_index=False, return_inverse=False, output = _unique1d(consolidated, return_index, return_inverse, return_counts) - if not (return_index or return_inverse or return_counts): - return reshape_uniq(output) - else: - uniq = reshape_uniq(output[0]) - return (uniq,) + output[1:] + output = (reshape_uniq(output[0]),) + output[1:] + return _unpack_tuple(output) + def _unique1d(ar, return_index=False, return_inverse=False, return_counts=False): @@ -259,19 +267,15 @@ def _unique1d(ar, return_index=False, return_inverse=False, ar = np.asanyarray(ar).flatten() optional_indices = return_index or return_inverse - optional_returns = optional_indices or return_counts if ar.size == 0: - if not optional_returns: - ret = ar - else: - ret = (ar,) - if return_index: - ret += (np.empty(0, np.intp),) - if return_inverse: - ret += (np.empty(0, np.intp),) - if return_counts: - ret += (np.empty(0, np.intp),) + ret = (ar,) + if return_index: + ret += (np.empty(0, np.intp),) + if return_inverse: + ret += (np.empty(0, np.intp),) + if return_counts: + ret += (np.empty(0, np.intp),) return ret if optional_indices: @@ -282,22 +286,20 @@ def _unique1d(ar, return_index=False, return_inverse=False, aux = ar flag = np.concatenate(([True], aux[1:] != aux[:-1])) - if not optional_returns: - ret = aux[flag] - else: - ret = (aux[flag],) - if return_index: - ret += (perm[flag],) - if return_inverse: - iflag = np.cumsum(flag) - 1 - inv_idx = np.empty(ar.shape, dtype=np.intp) - inv_idx[perm] = iflag - ret += (inv_idx,) - if return_counts: - idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) - ret += (np.diff(idx),) + ret = (aux[flag],) + if return_index: + ret += (perm[flag],) + if return_inverse: + iflag = np.cumsum(flag) - 1 + inv_idx = np.empty(ar.shape, dtype=np.intp) + inv_idx[perm] = iflag + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) + ret += (np.diff(idx),) return ret + def intersect1d(ar1, ar2, assume_unique=False): """ Find the intersection of two arrays. |