diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 69ef0be4f..a5bf4d0ea 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -892,6 +892,19 @@ def dsplit(ary,indices_or_sections): raise ValueError, 'vsplit only works on arrays of 3 or more dimensions' return split(ary,indices_or_sections,2) +def get_array_prepare(*args): + """Find the wrapper for the array with the highest priority. + + In case of ties, leftmost wins. If no wrapper is found, return None + """ + wrappers = [(getattr(x, '__array_priority__', 0), -i, + x.__array_prepare__) for i, x in enumerate(args) + if hasattr(x, '__array_prepare__')] + wrappers.sort() + if wrappers: + return wrappers[-1][-1] + return None + def get_array_wrap(*args): """Find the wrapper for the array with the highest priority. @@ -975,7 +988,6 @@ def kron(a,b): True """ - wrapper = get_array_wrap(a, b) b = asanyarray(b) a = array(a,copy=False,subok=True,ndmin=b.ndim) ndb, nda = b.ndim, a.ndim @@ -998,6 +1010,10 @@ def kron(a,b): axis = nd-1 for _ in xrange(nd): result = concatenate(result, axis=axis) + wrapper = get_array_prepare(a, b) + if wrapper is not None: + result = wrapper(result) + wrapper = get_array_wrap(a, b) if wrapper is not None: result = wrapper(result) return result |