summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 69ef0be4f..a5bf4d0ea 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -892,6 +892,19 @@ def dsplit(ary,indices_or_sections):
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)
+def get_array_prepare(*args):
+ """Find the wrapper for the array with the highest priority.
+
+ In case of ties, leftmost wins. If no wrapper is found, return None
+ """
+ wrappers = [(getattr(x, '__array_priority__', 0), -i,
+ x.__array_prepare__) for i, x in enumerate(args)
+ if hasattr(x, '__array_prepare__')]
+ wrappers.sort()
+ if wrappers:
+ return wrappers[-1][-1]
+ return None
+
def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
@@ -975,7 +988,6 @@ def kron(a,b):
True
"""
- wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
ndb, nda = b.ndim, a.ndim
@@ -998,6 +1010,10 @@ def kron(a,b):
axis = nd-1
for _ in xrange(nd):
result = concatenate(result, axis=axis)
+ wrapper = get_array_prepare(a, b)
+ if wrapper is not None:
+ result = wrapper(result)
+ wrapper = get_array_wrap(a, b)
if wrapper is not None:
result = wrapper(result)
return result