summaryrefslogtreecommitdiff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
authorTyler Reddy <tyler.je.reddy@gmail.com>2018-08-26 12:55:22 -0700
committerTyler Reddy <tyler.je.reddy@gmail.com>2018-08-26 12:55:22 -0700
commit6171296a661065de0d40c7be4c12666fcfaafd2b (patch)
treea9a530adf36d5debbcb1f144fcd23e33f461103f /numpy/lib/arraysetops.py
parent63a57c3f6a58c61763daf89a2d5495f6a855bf9b (diff)
downloadnumpy-6171296a661065de0d40c7be4c12666fcfaafd2b.tar.gz
BUG: add type cast check to ediff1d
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r--numpy/lib/arraysetops.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index d84455a8f..62e9b6d50 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -82,6 +82,11 @@ def ediff1d(ary, to_end=None, to_begin=None):
# force a 1d array
ary = np.asanyarray(ary).ravel()
+ # we have unit tests enforcing
+ # propagation of the dtype of input
+ # ary to returned result
+ dtype_req = ary.dtype
+
# fast track default case
if to_begin is None and to_end is None:
return ary[1:] - ary[:-1]
@@ -89,13 +94,23 @@ def ediff1d(ary, to_end=None, to_begin=None):
if to_begin is None:
l_begin = 0
else:
- to_begin = np.asanyarray(to_begin).ravel()
+ to_begin = np.asanyarray(to_begin)
+ if not np.can_cast(to_begin, dtype_req):
+ raise TypeError("dtype of to_begin must be compatible "
+ "with input ary")
+
+ to_begin = to_begin.ravel()
l_begin = len(to_begin)
if to_end is None:
l_end = 0
else:
- to_end = np.asanyarray(to_end).ravel()
+ to_end = np.asanyarray(to_end)
+ if not np.can_cast(to_end, dtype_req):
+ raise TypeError("dtype of to_end must be compatible "
+ "with input ary")
+
+ to_end = to_end.ravel()
l_end = len(to_end)
# do the calculation in place and copy to_begin and to_end