summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 5f21f9b34..ed84e9f5d 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,7 +1,7 @@
__all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
'column_stack','row_stack', 'dstack','array_split','split','hsplit',
'vsplit','dsplit','apply_over_axes','expand_dims',
- 'apply_along_axis', 'kron', 'tile']
+ 'apply_along_axis', 'kron', 'tile', 'get_array_wrap']
import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outer, \
@@ -526,7 +526,7 @@ def dsplit(ary,indices_or_sections):
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)
-def _getwrapper(*args):
+def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
In case of ties, leftmost wins. If no wrapper is found, return None
@@ -547,19 +547,28 @@ def kron(a,b):
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
- wrapper = _getwrapper(a, b)
+ wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
+ ndb, nda = b.ndim, a.ndim
+ if (nda == 0 or ndb == 0):
+ return a * b
as = a.shape
bs = b.shape
if not a.flags.contiguous:
a = reshape(a, as)
if not b.flags.contiguous:
b = reshape(b, bs)
- o = outer(a,b)
- result = o.reshape(as + bs)
- axis = a.ndim-1
- for k in xrange(b.ndim):
+ nd = ndb
+ if (ndb != nda):
+ if (ndb > nda):
+ as = (1,)*(ndb-nda) + as
+ else:
+ bs = (1,)*(nda-ndb) + bs
+ nd = nda
+ result = outer(a,b).reshape(as+bs)
+ axis = nd-1
+ for k in xrange(nd):
result = concatenate(result, axis=axis)
if wrapper is not None:
result = wrapper(result)