diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 5f21f9b34..ed84e9f5d 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,7 +1,7 @@ __all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack', 'column_stack','row_stack', 'dstack','array_split','split','hsplit', 'vsplit','dsplit','apply_over_axes','expand_dims', - 'apply_along_axis', 'kron', 'tile'] + 'apply_along_axis', 'kron', 'tile', 'get_array_wrap'] import numpy.core.numeric as _nx from numpy.core.numeric import asarray, zeros, newaxis, outer, \ @@ -526,7 +526,7 @@ def dsplit(ary,indices_or_sections): raise ValueError, 'vsplit only works on arrays of 3 or more dimensions' return split(ary,indices_or_sections,2) -def _getwrapper(*args): +def get_array_wrap(*args): """Find the wrapper for the array with the highest priority. In case of ties, leftmost wins. If no wrapper is found, return None @@ -547,19 +547,28 @@ def kron(a,b): [ ... ... ], [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]] """ - wrapper = _getwrapper(a, b) + wrapper = get_array_wrap(a, b) b = asanyarray(b) a = array(a,copy=False,subok=True,ndmin=b.ndim) + ndb, nda = b.ndim, a.ndim + if (nda == 0 or ndb == 0): + return a * b as = a.shape bs = b.shape if not a.flags.contiguous: a = reshape(a, as) if not b.flags.contiguous: b = reshape(b, bs) - o = outer(a,b) - result = o.reshape(as + bs) - axis = a.ndim-1 - for k in xrange(b.ndim): + nd = ndb + if (ndb != nda): + if (ndb > nda): + as = (1,)*(ndb-nda) + as + else: + bs = (1,)*(nda-ndb) + bs + nd = nda + result = outer(a,b).reshape(as+bs) + axis = nd-1 + for k in xrange(nd): result = concatenate(result, axis=axis) if wrapper is not None: result = wrapper(result) |