diff options
| -rw-r--r-- | numpy/lib/shape_base.py | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 581da0598..8371747c0 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1159,9 +1159,11 @@ def kron(a, b): bs = (1,)*max(0, nda-ndb) + bs # Compute the product - a_arr = _nx.asarray(a).reshape(a.size, 1) - b_arr = _nx.asarray(b).reshape(1, b.size) - result = a_arr * b_arr + a_arr = a.reshape(a.size, 1) + b_arr = b.reshape(1, b.size) + is_any_mat = isinstance(a_arr, matrix) or isinstance(b_arr, matrix) + # In case of `mat`, convert result to `array` + result = _nx.multiply(a_arr, b_arr, subok=(not is_any_mat)) # Reshape back result = result.reshape(as_+bs) @@ -1169,13 +1171,7 @@ def kron(a, b): result = result.transpose(transposer) result = result.reshape(_nx.multiply(as_, bs)) - wrapper = get_array_prepare(a, b) - if wrapper is not None: - result = wrapper(result) - wrapper = get_array_wrap(a, b) - if wrapper is not None: - result = wrapper(result) - return result + return result if not is_any_mat else matrix(result, copy=False) def _tile_dispatcher(A, reps): |
