summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/function_base.py20
-rw-r--r--numpy/lib/tests/test_function_base.py3
2 files changed, 9 insertions, 14 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index a0781ebf9..f0267afb4 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3519,24 +3519,18 @@ def insert(arr, obj, values, axis=None):
N = arr.shape[axis]
newshape = list(arr.shape)
if isinstance(obj, (int, long, integer)):
+
if (obj < 0): obj += N
if obj < 0 or obj > N:
raise ValueError(
"index (%d) out of range (0<=index<=%d) "\
"in dimension %d" % (obj, N, axis))
- newshape[axis] += 1;
- new = empty(newshape, arr.dtype, arr.flags.fnc)
- slobj[axis] = slice(None, obj)
- new[slobj] = arr[slobj]
- slobj[axis] = obj
- new[slobj] = values
- slobj[axis] = slice(obj+1,None)
- slobj2 = [slice(None)]*ndim
- slobj2[axis] = slice(obj,None)
- new[slobj] = arr[slobj2]
- if wrap:
- return wrap(new)
- return new
+
+ if isinstance(values, (int, long, integer)):
+ obj = [obj]
+ else:
+ obj = [obj] * len(values)
+
elif isinstance(obj, slice):
# turn it into a range object
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 95b32e47c..2ed6e7edd 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -145,7 +145,8 @@ class TestInsert(TestCase):
assert_equal(insert(a, 0, 1), [1, 1, 2, 3])
assert_equal(insert(a, 3, 1), [1, 2, 3, 1])
assert_equal(insert(a, [1, 1, 1], [1, 2, 3]), [1, 1, 2, 3, 2, 3])
-
+ assert_equal(insert(a, 1,[1,2,3]), [1, 1, 2, 3, 2, 3])
+ assert_equal(insert(a,[1,2,3],9),[1,9,2,9,3,9])
class TestAmax(TestCase):
def test_basic(self):