summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/shape_base.py16
1 files changed, 6 insertions, 10 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 581da0598..8371747c0 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1159,9 +1159,11 @@ def kron(a, b):
bs = (1,)*max(0, nda-ndb) + bs
# Compute the product
- a_arr = _nx.asarray(a).reshape(a.size, 1)
- b_arr = _nx.asarray(b).reshape(1, b.size)
- result = a_arr * b_arr
+ a_arr = a.reshape(a.size, 1)
+ b_arr = b.reshape(1, b.size)
+ is_any_mat = isinstance(a_arr, matrix) or isinstance(b_arr, matrix)
+ # In case of `mat`, convert result to `array`
+ result = _nx.multiply(a_arr, b_arr, subok=(not is_any_mat))
# Reshape back
result = result.reshape(as_+bs)
@@ -1169,13 +1171,7 @@ def kron(a, b):
result = result.transpose(transposer)
result = result.reshape(_nx.multiply(as_, bs))
- wrapper = get_array_prepare(a, b)
- if wrapper is not None:
- result = wrapper(result)
- wrapper = get_array_wrap(a, b)
- if wrapper is not None:
- result = wrapper(result)
- return result
+ return result if not is_any_mat else matrix(result, copy=False)
def _tile_dispatcher(A, reps):