diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 08:28:20 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 08:28:20 +0000 |
commit | 307b1cf34928495500aec4650e8dd6497fdc205c (patch) | |
tree | 5b735e1e726325a8d679f056f401495ec401ab36 /numpy/lib/function_base.py | |
parent | 5dd549836b872e7c65f2d6b69e303f953de0f488 (diff) | |
download | numpy-307b1cf34928495500aec4650e8dd6497fdc205c.tar.gz |
Fix how deletefrom and insertinto handle objects with __array_wrap__ defined.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 56 |
1 files changed, 45 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 2d20849c6..cadaa2524 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1036,14 +1036,23 @@ def deletefrom(arr, obj, axis=None): array([[3,4,5], [6,7,8]]) """ + try: + wrap = arr.__array_wrap__ + except AttributeError: + wrap = None + arr = asarray(arr) ndim = arr.ndim if axis is None: if ndim != 1: arr = arr.ravel() - axis = 0 + ndim = arr.ndim; + axis = ndim-1; if ndim == 0: - return arr.copy() + if wrap: + return wrap(arr) + else: + return arr.copy() slobj = [slice(None)]*ndim N = arr.shape[axis] newshape = list(arr.shape) @@ -1059,12 +1068,18 @@ def deletefrom(arr, obj, axis=None): slobj2 = [slice(None)]*ndim slobj[axis] = slice(ojb+1,None) new[slobj] = arr[slobj2] + if wrap: + return wrap(new) return new + elif isinstance(obj, slice): start, stop, step = obj.indices(N) numtodel = len(xrange(start, stop, step)) if numtodel <= 0: - return arr.copy() + if wrap: + return wrap(new) + else: + return arr.copy() newshape[axis] -= numtodel new = empty(newshape, arr.dtype, arr.flags.fnc) # copy initial chunk @@ -1092,15 +1107,21 @@ def deletefrom(arr, obj, axis=None): slobj2 = [slice(None)]*ndim slobj2[axis] = obj new[slobj] = arr[slobj2] + if wrap: + return wrap(new) return new # default behavior - obj = array(obj, dtype=intp, copy=0, ndmin=1) + obj = array(obj, dtype=intp, copy=0, ndmin=1) all = arange(N, dtype=intp) obj = setdiff1d(all, obj) slobj[axis] = obj slobj = tuple(slobj) - return arr[slobj] + new = arr[slobj] + if wrap: + return wrap(new) + else: + return new def insertinto(arr, obj, values, axis=None): """Return a new array with values inserted along the given axis @@ -1123,16 +1144,25 @@ def insertinto(arr, obj, values, axis=None): [5,5,5], [7,8,9]) """ + try: + wrap = arr.__array_wrap__ + except AttributeError: + wrap = None + arr = asarray(arr) - ndim = arr.ndim + ndim = arr.ndim if axis is None: if ndim != 1: arr = arr.ravel() - axis = 0 + ndim = arr.ndim + axis = ndim-1 if (ndim == 0): arr = arr.copy() arr[...] = values - return arr + if wrap: + return wrap(arr) + else: + return arr slobj = [slice(None)]*ndim N = arr.shape[axis] newshape = list(arr.shape) @@ -1150,6 +1180,8 @@ def insertinto(arr, obj, values, axis=None): slobj2 = [slice(None)]*ndim slobj2[axis] = slice(obj,None) new[slobj] = arr[slobj2] + if wrap: + return wrap(new) return new elif isinstance(obj, slice): # turn it into a range object @@ -1170,16 +1202,18 @@ def insertinto(arr, obj, values, axis=None): slobj2[axis] = index2 new[slobj] = values new[slobj2] = arr - + + if wrap: + return wrap(new) return new def appendonto(arr, values, axis=None): """Append to the end of an array along axis (ravel first if None) """ - arr = asarray(arr) + arr = asanyarray(arr) if axis is None: if arr.ndim != 1: arr = arr.ravel() values = ravel(values) - axis = 0 + axis = arr.ndim-1 return concatenate((arr, values), axis=axis) |