summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPim de Haan <pimdehaan@gmail.com>2017-06-10 18:57:04 +0200
committerPim de Haan <pimdehaan@gmail.com>2017-06-22 13:20:47 +0200
commit4f8d9523ef29267932c933bd7e06b0a4e59b1df0 (patch)
tree5d849b49c180461aa0edf67731b5d65a97bfc25a
parentcfb909f35de8ad238066eb176bc408d86f15c9c8 (diff)
downloadnumpy-4f8d9523ef29267932c933bd7e06b0a4e59b1df0.tar.gz
ENH: Make 0-length dim handling of tensordot consistent with dot/einsum
-rw-r--r--doc/release/1.14.0-notes.rst6
-rw-r--r--numpy/core/numeric.py10
-rw-r--r--numpy/core/tests/test_numeric.py11
3 files changed, 22 insertions, 5 deletions
diff --git a/doc/release/1.14.0-notes.rst b/doc/release/1.14.0-notes.rst
index aa4c69f9d..e819bfc2c 100644
--- a/doc/release/1.14.0-notes.rst
+++ b/doc/release/1.14.0-notes.rst
@@ -28,6 +28,12 @@ Build System Changes
Compatibility notes
===================
+``np.tensordot`` now returns zero array when contracting over 0-length dimension
+--------------------------------------------------------------------------------
+Previously ``np.tensordot`` raised a ValueError when contracting over 0-length
+dimension. Now it returns a zero array, which is consistent with the behaviour
+of ``np.dot`` and ``np.einsum``.
+
C API changes
=============
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 013c8a92a..f8ca49429 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -22,9 +22,9 @@ if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer
from . import umath
-from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
- ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG,
- ERR_DEFAULT, PINF, NAN)
+from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT,
+ ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT,
+ ERR_LOG, ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
from ._internal import TooHardError, AxisError
@@ -1326,7 +1326,7 @@ def tensordot(a, b, axes=2):
N2 = 1
for axis in axes_a:
N2 *= as_[axis]
- newshape_a = (-1, N2)
+ newshape_a = (multiply.reduce([as_[ax] for ax in notin]), N2)
olda = [as_[axis] for axis in notin]
notin = [k for k in range(ndb) if k not in axes_b]
@@ -1334,7 +1334,7 @@ def tensordot(a, b, axes=2):
N2 = 1
for axis in axes_b:
N2 *= bs[axis]
- newshape_b = (N2, -1)
+ newshape_b = (N2, multiply.reduce([bs[ax] for ax in notin]))
oldb = [bs[axis] for axis in notin]
at = a.transpose(newaxes_a).reshape(newshape_a)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 0f87ffdf2..a5149d4f7 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2664,5 +2664,16 @@ class TestKeepdims(TestCase):
assert_raises(TypeError, np.sum, x, keepdims=True)
+class TestTensordot(TestCase):
+
+ def test_zero_dimension(self):
+ # Test resolution to issue #5663
+ a = np.ndarray((3,0))
+ b = np.ndarray((0,4))
+ td = np.tensordot(a, b, (1, 0))
+ assert_array_equal(td, np.dot(a, b))
+ assert_array_equal(td, np.einsum('ij,jk', a, b))
+
+
if __name__ == "__main__":
run_module_suite()