diff options
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 120 |
1 files changed, 60 insertions, 60 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index b1e74dc74..e8eda297f 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -110,16 +110,25 @@ def ediff1d(ary, to_end=None, to_begin=None): return result +def _unpack_tuple(x): + """ Unpacks one-element tuples for use as return values """ + if len(x) == 1: + return x[0] + else: + return x + + def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ Find the unique elements of an array. Returns the sorted unique elements of an array. There are three optional - outputs in addition to the unique elements: the indices of the input array - that give the unique values, the indices of the unique array that - reconstruct the input array, and the number of times each unique value - comes up in the input array. + outputs in addition to the unique elements: + + * the indices of the input array that give the unique values + * the indices of the unique array that reconstruct the input array + * the number of times each unique value comes up in the input array Parameters ---------- @@ -139,16 +148,15 @@ def unique(ar, return_index=False, return_inverse=False, .. versionadded:: 1.9.0 axis : int or None, optional - The axis to operate on. If None, `ar` will be flattened beforehand. - Otherwise, duplicate items will be removed along the provided axis, - with all the other axes belonging to the each of the unique elements. - Object arrays or structured arrays that contain objects are not - supported if the `axis` kwarg is used. + The axis to operate on. If None, `ar` will be flattened. If an integer, + the subarrays indexed by the given axis will be flattened and treated + as the elements of a 1-D array with the dimension of the given axis, + see the notes for more details. Object arrays or structured arrays + that contain objects are not supported if the `axis` kwarg is used. The + default is None. .. versionadded:: 1.13.0 - - Returns ------- unique : ndarray @@ -170,6 +178,17 @@ def unique(ar, return_index=False, return_inverse=False, numpy.lib.arraysetops : Module with a number of other functions for performing set operations on arrays. + Notes + ----- + When an axis is specified the subarrays indexed by the axis are sorted. + This is done by making the specified axis the first dimension of the array + and then flattening the subarrays in C order. The flattened subarrays are + then viewed as a structured type with each element given a label, with the + effect that we end up with a 1-D array of structured types that can be + treated in the same way as any other 1-D array. The result is that the + flattened subarrays are sorted in lexicographic order starting with the + first element. + Examples -------- >>> np.unique([1, 1, 2, 2, 3, 3]) @@ -211,24 +230,21 @@ def unique(ar, return_index=False, return_inverse=False, """ ar = np.asanyarray(ar) if axis is None: - return _unique1d(ar, return_index, return_inverse, return_counts) - if not (-ar.ndim <= axis < ar.ndim): - raise ValueError('Invalid axis kwarg specified for unique') + ret = _unique1d(ar, return_index, return_inverse, return_counts) + return _unpack_tuple(ret) + + # axis was specified and not None + try: + ar = np.swapaxes(ar, axis, 0) + except np.AxisError: + # this removes the "axis1" or "axis2" prefix from the error message + raise np.AxisError(axis, ar.ndim) - ar = np.swapaxes(ar, axis, 0) - orig_shape, orig_dtype = ar.shape, ar.dtype # Must reshape to a contiguous 2D array for this to work... + orig_shape, orig_dtype = ar.shape, ar.dtype ar = ar.reshape(orig_shape[0], -1) ar = np.ascontiguousarray(ar) - - if ar.dtype.char in (np.typecodes['AllInteger'] + - np.typecodes['Datetime'] + 'S'): - # Optimization: Creating a view of your data with a np.void data type of - # size the number of bytes in a full row. Handles any type where items - # have a unique binary representation, i.e. 0 is only 0, not +0 and -0. - dtype = np.dtype((np.void, ar.dtype.itemsize * ar.shape[1])) - else: - dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])] + dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])] try: consolidated = ar.view(dtype) @@ -245,11 +261,9 @@ def unique(ar, return_index=False, return_inverse=False, output = _unique1d(consolidated, return_index, return_inverse, return_counts) - if not (return_index or return_inverse or return_counts): - return reshape_uniq(output) - else: - uniq = reshape_uniq(output[0]) - return (uniq,) + output[1:] + output = (reshape_uniq(output[0]),) + output[1:] + return _unpack_tuple(output) + def _unique1d(ar, return_index=False, return_inverse=False, return_counts=False): @@ -259,20 +273,6 @@ def _unique1d(ar, return_index=False, return_inverse=False, ar = np.asanyarray(ar).flatten() optional_indices = return_index or return_inverse - optional_returns = optional_indices or return_counts - - if ar.size == 0: - if not optional_returns: - ret = ar - else: - ret = (ar,) - if return_index: - ret += (np.empty(0, np.intp),) - if return_inverse: - ret += (np.empty(0, np.intp),) - if return_counts: - ret += (np.empty(0, np.intp),) - return ret if optional_indices: perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') @@ -280,24 +280,24 @@ def _unique1d(ar, return_index=False, return_inverse=False, else: ar.sort() aux = ar - flag = np.concatenate(([True], aux[1:] != aux[:-1])) - - if not optional_returns: - ret = aux[flag] - else: - ret = (aux[flag],) - if return_index: - ret += (perm[flag],) - if return_inverse: - iflag = np.cumsum(flag) - 1 - inv_idx = np.empty(ar.shape, dtype=np.intp) - inv_idx[perm] = iflag - ret += (inv_idx,) - if return_counts: - idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) - ret += (np.diff(idx),) + mask = np.empty(aux.shape, dtype=np.bool_) + mask[:1] = True + mask[1:] = aux[1:] != aux[:-1] + + ret = (aux[mask],) + if return_index: + ret += (perm[mask],) + if return_inverse: + imask = np.cumsum(mask) - 1 + inv_idx = np.empty(mask.shape, dtype=np.intp) + inv_idx[perm] = imask + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(mask) + ([mask.size],)) + ret += (np.diff(idx),) return ret + def intersect1d(ar1, ar2, assume_unique=False): """ Find the intersection of two arrays. |