From b00fb8dcaf946fc855970763a15795993626ed0e Mon Sep 17 00:00:00 2001 From: Tim Hochberg Date: Thu, 20 Apr 2006 03:52:14 +0000 Subject: Fix kron so that the return type reflects the type of its arguments. Also, raise an exception if the arguments are not rank-2 since the other cases were some combination of ambiguous or broken. --- numpy/lib/shape_base.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) (limited to 'numpy/lib/shape_base.py') diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 2d2e6f337..c4f519d30 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -5,7 +5,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack', import numpy.core.numeric as _nx from numpy.core.numeric import asarray, zeros, newaxis, outerproduct, \ - concatenate, isscalar, array + concatenate, isscalar, array, asanyarray from numpy.core.oldnumeric import product, reshape def apply_along_axis(func1d,axis,arr,*args): @@ -544,7 +544,19 @@ def repmat(a, m, n): return c.reshape(rows, cols) -# TODO: figure out how to keep arrays the same +def _getwrapper(*args): + """Find the wrapper for the array with the highest priority. + + In case of ties, leftmost wins. If no wrapper is found, return None + """ + wrappers = [(getattr(x, '__array_priority__', 0), -i, + x.__array_wrap__) for i, x in enumerate(args) + if hasattr(x, '__array_wrap__')] + wrappers.sort() + if wrappers: + return wrappers[-1][-1] + return None + def kron(a,b): """kronecker product of a and b @@ -553,10 +565,18 @@ def kron(a,b): [ ... ... ], [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]] """ + wrapper = _getwrapper(a, b) + a = asanyarray(a) + b = asanyarray(b) + if not (len(a.shape) == len(b.shape) == 2): + raise ValueError("a and b must both be two dimensional") if not a.flags.contiguous: a = reshape(a, a.shape) if not b.flags.contiguous: b = reshape(b, b.shape) o = outerproduct(a,b) o=o.reshape(a.shape + b.shape) - return concatenate(concatenate(o, axis=1), axis=1) + result = concatenate(concatenate(o, axis=1), axis=1) + if wrapper is not None: + result = wrapper(result) + return result -- cgit v1.2.1