summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-26 07:50:53 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-26 07:50:53 +0000
commit5dd549836b872e7c65f2d6b69e303f953de0f488 (patch)
tree1c389c7006f47b4497e404e3cb9d68d500d8d8b5 /numpy/lib/function_base.py
parentc37cfa5b256d257a339b30334392c93ae8b7d78a (diff)
downloadnumpy-5dd549836b872e7c65f2d6b69e303f953de0f488.tar.gz
Fix broadcast-copy on fancy set-item.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py53
1 files changed, 36 insertions, 17 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index b03d0ca1a..2d20849c6 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1106,7 +1106,22 @@ def insertinto(arr, obj, values, axis=None):
"""Return a new array with values inserted along the given axis
before the given indices
- If axis is None, then ravel the array first.
+ If axis is None, then ravel the array first.
+
+ The obj argument can be an integer, a slice, or a sequence of
+ integers.
+
+ Example:
+ >>> a = array([[1,2,3],
+ [4,5,6],
+ [7,8,9]])
+
+ >>> insertinto(a, [1,2], [[4],[5]], axis=0)
+ array([[1,2,3],
+ [4,4,4],
+ [4,5,6],
+ [5,5,5],
+ [7,8,9])
"""
arr = asarray(arr)
ndim = arr.ndim
@@ -1139,28 +1154,32 @@ def insertinto(arr, obj, values, axis=None):
elif isinstance(obj, slice):
# turn it into a range object
obj = arange(*obj.indices(N),**{'dtype':intp})
-
- # default behavior
- # FIXME: this is too slow
- obj = array(obj, dtype=intp, copy=0, ndmin=1)
- try:
- if len(values) != len(obj):
- raise TypeError
- except TypeError:
- values = [values]*len(obj)
- new = arr
- k = 0
- for item, val in zip(obj, values):
- new = insertinto(new, item+k, val, axis=axis)
+
+ # get two sets of indices
+ # one is the indices which will hold the new stuff
+ # two is the indices where arr will be copied over
+
+ obj = asarray(obj, dtype=intp)
+ numnew = len(obj)
+ index1 = obj + arange(numnew)
+ index2 = setdiff1d(arange(numnew+N),index1)
+ newshape[axis] += numnew
+ new = empty(newshape, arr.dtype, arr.flags.fnc)
+ slobj2 = [slice(None)]*ndim
+ slobj[axis] = index1
+ slobj2[axis] = index2
+ new[slobj] = values
+ new[slobj2] = arr
+
return new
-def appendonto(arr, obj, axis=None):
+def appendonto(arr, values, axis=None):
"""Append to the end of an array along axis (ravel first if None)
"""
arr = asarray(arr)
if axis is None:
if arr.ndim != 1:
arr = arr.ravel()
- obj = ravel(obj)
+ values = ravel(values)
axis = 0
- return concatenate((arr, obj), axis=axis)
+ return concatenate((arr, values), axis=axis)