summary | refs | log | tree | commit | diff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
author: Warren Weckesser <warren.weckesser@gmail.com> 2020-03-12 19:53:13 -0400
committer: Warren Weckesser <warren.weckesser@gmail.com> 2020-03-15 11:35:39 -0400
commit: d18156b00d4215d869d71f1d4c7d66037bd50b4b (patch)
tree: 11f7955d2bdb480efecc63cadb4b56c744597597 /numpy/lib/arraysetops.py
parent: 353bf737fb84fe285b430e6496aa7537f56a674f (diff)
download: numpy-d18156b00d4215d869d71f1d4c7d66037bd50b4b.tar.gz
BUG: lib: Handle axes with length 0 in np.unique.
Tweak a few lines so that arrays with an axis with length 0 don't break the np.unique code. Closes gh-15559.
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- numpy/lib/arraysetops.py 20
1 file changed, 16 insertions, 4 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index 0f2d082c5..22687b941 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -270,20 +270,33 @@ def unique(ar, return_index=False, return_inverse=False,
# Must reshape to a contiguous 2D array for this to work...
orig_shape, orig_dtype = ar.shape, ar.dtype
- ar = ar.reshape(orig_shape[0], -1)
+ ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))
ar = np.ascontiguousarray(ar)
dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]
+ # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured
+ # data type with `m` fields where each field has the data type of `ar`.
+ # In the following, we create the array `consolidated`, which has
+ # shape `(n,)` with data type `dtype`.
try:
- consolidated = ar.view(dtype)
+ if ar.shape[1] > 0:
+ consolidated = ar.view(dtype)
+ else:
+ # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is
+ # a data type with itemsize 0, and the call `ar.view(dtype)` will
+ # fail. Instead, we'll use `np.empty` to explicitly create the
+ # array with shape `(len(ar),)`. Because `dtype` in this case has
+ # itemsize 0, the total size of the result is still 0 bytes.
+ consolidated = np.empty(len(ar), dtype=dtype)
except TypeError:
# There's no good way to do this for object arrays, etc...
msg = 'The axis argument to unique is not supported for dtype {dt}'
raise TypeError(msg.format(dt=ar.dtype))
def reshape_uniq(uniq):
+ n = len(uniq)
uniq = uniq.view(orig_dtype)
- uniq = uniq.reshape(-1, *orig_shape[1:])
+ uniq = uniq.reshape(n, *orig_shape[1:])
uniq = np.moveaxis(uniq, 0, axis)
return uniq
@@ -783,4 +796,3 @@ def setdiff1d(ar1, ar2, assume_unique=False):
ar1 = unique(ar1)
ar2 = unique(ar2)
return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)]
-