summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-26 08:28:20 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-26 08:28:20 +0000
commit307b1cf34928495500aec4650e8dd6497fdc205c (patch)
tree5b735e1e726325a8d679f056f401495ec401ab36 /numpy/lib/function_base.py
parent5dd549836b872e7c65f2d6b69e303f953de0f488 (diff)
downloadnumpy-307b1cf34928495500aec4650e8dd6497fdc205c.tar.gz
Fix how deletefrom and insertinto handle objects with __array_wrap__ defined.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py56
1 files changed, 45 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 2d20849c6..cadaa2524 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1036,14 +1036,23 @@ def deletefrom(arr, obj, axis=None):
array([[3,4,5],
[6,7,8]])
"""
+ try:
+ wrap = arr.__array_wrap__
+ except AttributeError:
+ wrap = None
+
arr = asarray(arr)
ndim = arr.ndim
if axis is None:
if ndim != 1:
arr = arr.ravel()
- axis = 0
+ ndim = arr.ndim;
+ axis = ndim-1;
if ndim == 0:
- return arr.copy()
+ if wrap:
+ return wrap(arr)
+ else:
+ return arr.copy()
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
@@ -1059,12 +1068,18 @@ def deletefrom(arr, obj, axis=None):
slobj2 = [slice(None)]*ndim
slobj[axis] = slice(ojb+1,None)
new[slobj] = arr[slobj2]
+ if wrap:
+ return wrap(new)
return new
+
elif isinstance(obj, slice):
start, stop, step = obj.indices(N)
numtodel = len(xrange(start, stop, step))
if numtodel <= 0:
- return arr.copy()
+ if wrap:
+ return wrap(new)
+ else:
+ return arr.copy()
newshape[axis] -= numtodel
new = empty(newshape, arr.dtype, arr.flags.fnc)
# copy initial chunk
@@ -1092,15 +1107,21 @@ def deletefrom(arr, obj, axis=None):
slobj2 = [slice(None)]*ndim
slobj2[axis] = obj
new[slobj] = arr[slobj2]
+ if wrap:
+ return wrap(new)
return new
# default behavior
- obj = array(obj, dtype=intp, copy=0, ndmin=1)
+ obj = array(obj, dtype=intp, copy=0, ndmin=1)
all = arange(N, dtype=intp)
obj = setdiff1d(all, obj)
slobj[axis] = obj
slobj = tuple(slobj)
- return arr[slobj]
+ new = arr[slobj]
+ if wrap:
+ return wrap(new)
+ else:
+ return new
def insertinto(arr, obj, values, axis=None):
"""Return a new array with values inserted along the given axis
@@ -1123,16 +1144,25 @@ def insertinto(arr, obj, values, axis=None):
[5,5,5],
[7,8,9])
"""
+ try:
+ wrap = arr.__array_wrap__
+ except AttributeError:
+ wrap = None
+
arr = asarray(arr)
- ndim = arr.ndim
+ ndim = arr.ndim
if axis is None:
if ndim != 1:
arr = arr.ravel()
- axis = 0
+ ndim = arr.ndim
+ axis = ndim-1
if (ndim == 0):
arr = arr.copy()
arr[...] = values
- return arr
+ if wrap:
+ return wrap(arr)
+ else:
+ return arr
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
@@ -1150,6 +1180,8 @@ def insertinto(arr, obj, values, axis=None):
slobj2 = [slice(None)]*ndim
slobj2[axis] = slice(obj,None)
new[slobj] = arr[slobj2]
+ if wrap:
+ return wrap(new)
return new
elif isinstance(obj, slice):
# turn it into a range object
@@ -1170,16 +1202,18 @@ def insertinto(arr, obj, values, axis=None):
slobj2[axis] = index2
new[slobj] = values
new[slobj2] = arr
-
+
+ if wrap:
+ return wrap(new)
return new
def appendonto(arr, values, axis=None):
"""Append to the end of an array along axis (ravel first if None)
"""
- arr = asarray(arr)
+ arr = asanyarray(arr)
if axis is None:
if arr.ndim != 1:
arr = arr.ravel()
values = ravel(values)
- axis = 0
+ axis = arr.ndim-1
return concatenate((arr, values), axis=axis)