diff options
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index 2309f7e42..d65316598 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -94,8 +94,7 @@ def ediff1d(ary, to_end=None, to_begin=None): # force a 1d array ary = np.asanyarray(ary).ravel() - # enforce propagation of the dtype of input - # ary to returned result + # enforce that the dtype of `ary` is used for the output dtype_req = ary.dtype # fast track default case @@ -105,22 +104,23 @@ def ediff1d(ary, to_end=None, to_begin=None): if to_begin is None: l_begin = 0 else: - _to_begin = np.asanyarray(to_begin, dtype=dtype_req) - if not np.all(_to_begin == to_begin): - raise ValueError("cannot convert 'to_begin' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_begin = _to_begin.ravel() + to_begin = np.asanyarray(to_begin) + if not np.can_cast(to_begin, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_begin = to_begin.ravel() l_begin = len(to_begin) if to_end is None: l_end = 0 else: - _to_end = np.asanyarray(to_end, dtype=dtype_req) - # check that casting has not overflowed - if not np.all(_to_end == to_end): - raise ValueError("cannot convert 'to_end' to array with dtype " - "'%r' as required for input ary" % dtype_req) - to_end = _to_end.ravel() + to_end = np.asanyarray(to_end) + if not np.can_cast(to_end, dtype_req, casting="same_kind"): + raise TypeError("dtype of `to_end` must be compatible " + "with input `ary` under the `same_kind` rule.") + + to_end = to_end.ravel() l_end = len(to_end) # do the calculation in place and copy to_begin and to_end |