diff options
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 52 |
1 file changed, 31 insertions, 21 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index 2309f7e42..22687b941 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -25,8 +25,6 @@ To do: Optionally return indices analogously to unique for all functions. :Author: Robert Cimrman """ -from __future__ import division, absolute_import, print_function - import functools import numpy as np @@ -94,8 +92,7 @@ def ediff1d(ary, to_end=None, to_begin=None): # force a 1d array ary = np.asanyarray(ary).ravel() - # enforce propagation of the dtype of input - # ary to returned result + # enforce that the dtype of `ary` is used for the output dtype_req = ary.dtype # fast track default case @@ -105,22 +102,23 @@ def ediff1d(ary, to_end=None, to_begin=None): if to_begin is None: l_begin = 0 else: - _to_begin = np.asanyarray(to_begin, dtype=dtype_req) - if not np.all(_to_begin == to_begin): - raise ValueError("cannot convert 'to_begin' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_begin = _to_begin.ravel() + to_begin = np.asanyarray(to_begin) + if not np.can_cast(to_begin, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_begin = to_begin.ravel() l_begin = len(to_begin) if to_end is None: l_end = 0 else: - _to_end = np.asanyarray(to_end, dtype=dtype_req) - # check that casting has not overflowed - if not np.all(_to_end == to_end): - raise ValueError("cannot convert 'to_end' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_end = _to_end.ravel() + to_end = np.asanyarray(to_end) + if not np.can_cast(to_end, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_end = to_end.ravel() l_end = len(to_end) # do the calculation in place and copy to_begin and to_end @@ -253,9 +251,9 @@ def unique(ar, return_index=False, return_inverse=False, >>> u 
array([1, 2, 3, 4, 6]) >>> indices - array([0, 1, 4, ..., 1, 2, 1]) + array([0, 1, 4, 3, 1, 2, 1]) >>> u[indices] - array([1, 2, 6, ..., 2, 3, 2]) + array([1, 2, 6, 4, 2, 3, 2]) """ ar = np.asanyarray(ar) @@ -272,20 +270,33 @@ def unique(ar, return_index=False, return_inverse=False, # Must reshape to a contiguous 2D array for this to work... orig_shape, orig_dtype = ar.shape, ar.dtype - ar = ar.reshape(orig_shape[0], -1) + ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp)) ar = np.ascontiguousarray(ar) dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])] + # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured + # data type with `m` fields where each field has the data type of `ar`. + # In the following, we create the array `consolidated`, which has + # shape `(n,)` with data type `dtype`. try: - consolidated = ar.view(dtype) + if ar.shape[1] > 0: + consolidated = ar.view(dtype) + else: + # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is + # a data type with itemsize 0, and the call `ar.view(dtype)` will + # fail. Instead, we'll use `np.empty` to explicitly create the + # array with shape `(len(ar),)`. Because `dtype` in this case has + # itemsize 0, the total size of the result is still 0 bytes. + consolidated = np.empty(len(ar), dtype=dtype) except TypeError: # There's no good way to do this for object arrays, etc... msg = 'The axis argument to unique is not supported for dtype {dt}' raise TypeError(msg.format(dt=ar.dtype)) def reshape_uniq(uniq): + n = len(uniq) uniq = uniq.view(orig_dtype) - uniq = uniq.reshape(-1, *orig_shape[1:]) + uniq = uniq.reshape(n, *orig_shape[1:]) uniq = np.moveaxis(uniq, 0, axis) return uniq @@ -785,4 +796,3 @@ def setdiff1d(ar1, ar2, assume_unique=False): ar1 = unique(ar1) ar2 = unique(ar2) return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)] - |