summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 2d2e6f337..c4f519d30 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -5,7 +5,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outerproduct, \
- concatenate, isscalar, array
+ concatenate, isscalar, array, asanyarray
from numpy.core.oldnumeric import product, reshape
def apply_along_axis(func1d,axis,arr,*args):
@@ -544,7 +544,19 @@ def repmat(a, m, n):
return c.reshape(rows, cols)
-# TODO: figure out how to keep arrays the same
+def _getwrapper(*args):
+ """Find the wrapper for the array with the highest priority.
+
+ In case of ties, leftmost wins. If no wrapper is found, return None
+ """
+ wrappers = [(getattr(x, '__array_priority__', 0), -i,
+ x.__array_wrap__) for i, x in enumerate(args)
+ if hasattr(x, '__array_wrap__')]
+ wrappers.sort()
+ if wrappers:
+ return wrappers[-1][-1]
+ return None
+
def kron(a,b):
"""kronecker product of a and b
@@ -553,10 +565,18 @@ def kron(a,b):
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
+ wrapper = _getwrapper(a, b)
+ a = asanyarray(a)
+ b = asanyarray(b)
+ if not (len(a.shape) == len(b.shape) == 2):
+ raise ValueError("a and b must both be two dimensional")
if not a.flags.contiguous:
a = reshape(a, a.shape)
if not b.flags.contiguous:
b = reshape(b, b.shape)
o = outerproduct(a,b)
o=o.reshape(a.shape + b.shape)
- return concatenate(concatenate(o, axis=1), axis=1)
+ result = concatenate(concatenate(o, axis=1), axis=1)
+ if wrapper is not None:
+ result = wrapper(result)
+ return result