diff options
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 65 |
1 files changed, 57 insertions, 8 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index e8eda297f..4d3f35183 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -298,7 +298,7 @@ def _unique1d(ar, return_index=False, return_inverse=False, return ret -def intersect1d(ar1, ar2, assume_unique=False): +def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): """ Find the intersection of two arrays. @@ -307,15 +307,28 @@ def intersect1d(ar1, ar2, assume_unique=False): Parameters ---------- ar1, ar2 : array_like - Input arrays. + Input arrays. Will be flattened if not already 1D. assume_unique : bool If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False. - + return_indices : bool + If True, the indices which correspond to the intersection of the + two arrays are returned. The first instance of a value is used + if there are multiple. Default is False. + + .. versionadded:: 1.15.0 + Returns ------- intersect1d : ndarray Sorted 1D array of common and unique elements. + comm1 : ndarray + The indices of the first occurrences of the common values in `ar1`. + Only provided if `return_indices` is True. + comm2 : ndarray + The indices of the first occurrences of the common values in `ar2`. + Only provided if `return_indices` is True. + See Also -------- @@ -332,14 +345,49 @@ def intersect1d(ar1, ar2, assume_unique=False): >>> from functools import reduce >>> reduce(np.intersect1d, ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2])) array([3]) + + To return the indices of the values common to the input arrays + along with the intersected values: + >>> x = np.array([1, 1, 2, 3, 4]) + >>> y = np.array([2, 1, 4, 6]) + >>> xy, x_ind, y_ind = np.intersect1d(x, y, return_indices=True) + >>> x_ind, y_ind + (array([0, 2, 4]), array([1, 0, 2])) + >>> xy, x[x_ind], y[y_ind] + (array([1, 2, 4]), array([1, 2, 4]), array([1, 2, 4])) + """ if not assume_unique: - # Might be faster than unique( intersect1d( ar1, ar2 ) )? - ar1 = unique(ar1) - ar2 = unique(ar2) + if return_indices: + ar1, ind1 = unique(ar1, return_index=True) + ar2, ind2 = unique(ar2, return_index=True) + else: + ar1 = unique(ar1) + ar2 = unique(ar2) + else: + ar1 = ar1.ravel() + ar2 = ar2.ravel() + aux = np.concatenate((ar1, ar2)) - aux.sort() - return aux[:-1][aux[1:] == aux[:-1]] + if return_indices: + aux_sort_indices = np.argsort(aux, kind='mergesort') + aux = aux[aux_sort_indices] + else: + aux.sort() + + mask = aux[1:] == aux[:-1] + int1d = aux[:-1][mask] + + if return_indices: + ar1_indices = aux_sort_indices[:-1][mask] + ar2_indices = aux_sort_indices[1:][mask] - ar1.size + if not assume_unique: + ar1_indices = ind1[ar1_indices] + ar2_indices = ind2[ar2_indices] + + return int1d, ar1_indices, ar2_indices + else: + return int1d def setxor1d(ar1, ar2, assume_unique=False): """ @@ -660,3 +708,4 @@ def setdiff1d(ar1, ar2, assume_unique=False): ar1 = unique(ar1) ar2 = unique(ar2) return ar1[in1d(ar1, ar2, assume_unique=True, invert=True)] + |