diff options
author | Pim de Haan <pimdehaan@gmail.com> | 2017-06-10 18:57:04 +0200 |
---|---|---|
committer | Pim de Haan <pimdehaan@gmail.com> | 2017-06-22 13:20:47 +0200 |
commit | 4f8d9523ef29267932c933bd7e06b0a4e59b1df0 (patch) | |
tree | 5d849b49c180461aa0edf67731b5d65a97bfc25a /numpy/core/numeric.py | |
parent | cfb909f35de8ad238066eb176bc408d86f15c9c8 (diff) | |
download | numpy-4f8d9523ef29267932c933bd7e06b0a4e59b1df0.tar.gz |
ENH: Make 0-length dim handling of tensordot consistent with dot/einsum
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 013c8a92a..f8ca49429 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -22,9 +22,9 @@ if sys.version_info[0] < 3: from .multiarray import newbuffer, getbuffer from . import umath -from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, - ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, - ERR_DEFAULT, PINF, NAN) +from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT, + ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, + ERR_LOG, ERR_DEFAULT, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ from ._internal import TooHardError, AxisError @@ -1326,7 +1326,7 @@ def tensordot(a, b, axes=2): N2 = 1 for axis in axes_a: N2 *= as_[axis] - newshape_a = (-1, N2) + newshape_a = (multiply.reduce([as_[ax] for ax in notin]), N2) olda = [as_[axis] for axis in notin] notin = [k for k in range(ndb) if k not in axes_b] @@ -1334,7 +1334,7 @@ def tensordot(a, b, axes=2): N2 = 1 for axis in axes_b: N2 *= bs[axis] - newshape_b = (N2, -1) + newshape_b = (N2, multiply.reduce([bs[ax] for ax in notin])) oldb = [bs[axis] for axis in notin] at = a.transpose(newaxes_a).reshape(newshape_a) |