summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHan Genuit <hangenuit@gmail.com>2012-09-07 01:27:58 +0200
committerHan Genuit <hangenuit@gmail.com>2012-09-07 02:00:10 +0200
commit926564c195d30542312123b7d76fe091c3453881 (patch)
tree93e9e65b93b178af61c7e33395d0bbc37a62a66d
parenta72ce7edc9ff9e98ba73251c626abccb0691415e (diff)
downloadnumpy-926564c195d30542312123b7d76fe091c3453881.tar.gz
BUG: Fix for issues #378 and #392
This should fix the problems with numpy.insert(), where the input values were not checked for all scalar types and where values did not get inserted properly, but got duplicated by default.
-rw-r--r--numpy/lib/function_base.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 247f16560..2b1d780d2 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3591,19 +3591,21 @@ def insert(arr, obj, values, axis=None):
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
- if isinstance(obj, (int, long, integer)):
+ if isinstance(obj, (int, long, integer)):
if (obj < 0): obj += N
if obj < 0 or obj > N:
raise ValueError(
"index (%d) out of range (0<=index<=%d) "\
"in dimension %d" % (obj, N, axis))
-
- if isinstance(values, (int, long, integer)):
- obj = [obj]
+ if isscalar(values):
+ obj = [obj]
else:
- obj = [obj] * len(values)
-
+ values = asarray(values)
+ if ndim > values.ndim:
+ obj = [obj]
+ else:
+ obj = [obj] * len(values)
elif isinstance(obj, slice):
# turn it into a range object