diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 07:50:53 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 07:50:53 +0000 |
commit | 5dd549836b872e7c65f2d6b69e303f953de0f488 (patch) | |
tree | 1c389c7006f47b4497e404e3cb9d68d500d8d8b5 /numpy/lib/function_base.py | |
parent | c37cfa5b256d257a339b30334392c93ae8b7d78a (diff) | |
download | numpy-5dd549836b872e7c65f2d6b69e303f953de0f488.tar.gz |
Fix broadcast-copy on fancy set-item.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 53 |
1 files changed, 36 insertions, 17 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index b03d0ca1a..2d20849c6 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1106,7 +1106,22 @@ def insertinto(arr, obj, values, axis=None): """Return a new array with values inserted along the given axis before the given indices - If axis is None, then ravel the array first. + If axis is None, then ravel the array first. + + The obj argument can be an integer, a slice, or a sequence of + integers. + + Example: + >>> a = array([[1,2,3], + [4,5,6], + [7,8,9]]) + + >>> insertinto(a, [1,2], [[4],[5]], axis=0) + array([[1,2,3], + [4,4,4], + [4,5,6], + [5,5,5], + [7,8,9]) """ arr = asarray(arr) ndim = arr.ndim @@ -1139,28 +1154,32 @@ def insertinto(arr, obj, values, axis=None): elif isinstance(obj, slice): # turn it into a range object obj = arange(*obj.indices(N),**{'dtype':intp}) - - # default behavior - # FIXME: this is too slow - obj = array(obj, dtype=intp, copy=0, ndmin=1) - try: - if len(values) != len(obj): - raise TypeError - except TypeError: - values = [values]*len(obj) - new = arr - k = 0 - for item, val in zip(obj, values): - new = insertinto(new, item+k, val, axis=axis) + + # get two sets of indices + # one is the indices which will hold the new stuff + # two is the indices where arr will be copied over + + obj = asarray(obj, dtype=intp) + numnew = len(obj) + index1 = obj + arange(numnew) + index2 = setdiff1d(arange(numnew+N),index1) + newshape[axis] += numnew + new = empty(newshape, arr.dtype, arr.flags.fnc) + slobj2 = [slice(None)]*ndim + slobj[axis] = index1 + slobj2[axis] = index2 + new[slobj] = values + new[slobj2] = arr + return new -def appendonto(arr, obj, axis=None): +def appendonto(arr, values, axis=None): """Append to the end of an array along axis (ravel first if None) """ arr = asarray(arr) if axis is None: if arr.ndim != 1: arr = arr.ravel() - obj = ravel(obj) + values = ravel(values) axis = 0 - return concatenate((arr, obj), axis=axis) + return concatenate((arr, values), axis=axis) |