summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2016-03-20 14:32:21 +0100
committerSebastian Berg <sebastian@sipsolutions.net>2016-03-20 14:41:43 +0100
commit582600415d966dbfb7abafd3151415f59c6c422d (patch)
tree25dc92f351133e14be790555b168c09126b5da33
parent41ec9fc8b7ee9b710c3054b9e9409a7a1f6082f7 (diff)
downloadnumpy-582600415d966dbfb7abafd3151415f59c6c422d.tar.gz
BUG: Do not try sequence repeat unless necessary
Only if the other class actually implements sequence repeat, should it be tried. Otherwise classes which implement neither will cause the sequence repeat branch. This branch may fail for two reasons: 1. The scalar was not integer 2. The other object raises an error during repeat. Previously, the first did not happen for floats, etc. and the second was using a try/except logic. Now both will always error. An object which claims to support sequence repeat should never have to fall back to normal multiplication. Note that all of this is controversial in the final behaviour. We may actually want `[1, 2, 3] * np.float64(3.)` to return an array with `[3., 6., 9.]` at some point. - It was suggested to create a better error message, this is not covered here. The major point is fixed though, so: closing gh-7428
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src28
-rw-r--r--numpy/core/tests/test_scalarmath.py40
2 files changed, 58 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 3d5e51ba9..08505c5c7 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -264,10 +264,20 @@ gentype_@name@(PyObject *m1, PyObject *m2)
static PyObject *
gentype_multiply(PyObject *m1, PyObject *m2)
{
- PyObject *ret = NULL;
npy_intp repeat;
+ /*
+ * If the other object supports sequence repeat and not number multiply
+ * we should call sequence repeat to support e.g. list repeat by numpy
+ * scalars (they may be converted to ndarray otherwise).
+ * A python defined class will always only have the nb_multiply slot and
+ * some classes may have neither defined. For the latter we want need
+ * to give the normal case a chance to convert the object to ndarray.
+ * Probably no class has both defined, but if they do, prefer number.
+ */
if (!PyArray_IsScalar(m1, Generic) &&
+ ((Py_TYPE(m1)->tp_as_sequence != NULL) &&
+ (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) &&
((Py_TYPE(m1)->tp_as_number == NULL) ||
(Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) {
/* Try to convert m2 to an int and try sequence repeat */
@@ -276,9 +286,11 @@ gentype_multiply(PyObject *m1, PyObject *m2)
return NULL;
}
/* Note that npy_intp is compatible to Py_Ssize_t */
- ret = PySequence_Repeat(m1, repeat);
+ return PySequence_Repeat(m1, repeat);
}
- else if (!PyArray_IsScalar(m2, Generic) &&
+ if (!PyArray_IsScalar(m2, Generic) &&
+ ((Py_TYPE(m2)->tp_as_sequence != NULL) &&
+ (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) &&
((Py_TYPE(m2)->tp_as_number == NULL) ||
(Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) {
/* Try to convert m1 to an int and try sequence repeat */
@@ -286,13 +298,11 @@ gentype_multiply(PyObject *m1, PyObject *m2)
if (repeat == -1 && PyErr_Occurred()) {
return NULL;
}
- ret = PySequence_Repeat(m2, repeat);
- }
- if (ret == NULL) {
- PyErr_Clear(); /* no effect if not set */
- ret = PyArray_Type.tp_as_number->nb_multiply(m1, m2);
+ return PySequence_Repeat(m2, repeat);
}
- return ret;
+
+ /* All normal cases are handled by PyArray's multiply */
+ return PyArray_Type.tp_as_number->nb_multiply(m1, m2);
}
/**begin repeat
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index 12b1a0fe3..62faa763f 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -9,7 +9,7 @@ import numpy as np
from numpy.testing.utils import _gen_alignment_data
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal, assert_raises,
- assert_almost_equal, assert_allclose
+ assert_almost_equal, assert_allclose, assert_array_equal
)
types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc,
@@ -448,6 +448,44 @@ class TestSizeOf(TestCase):
assert_raises(TypeError, d.__sizeof__, "a")
+class TestMultiply(TestCase):
+ def test_seq_repeat(self):
+ # Test that basic sequences get repeated when multiplied with
+ # numpy integers. And errors are raised when multiplied with others.
+ # Some of this behaviour may be controversial and could be open for
+ # change.
+ for seq_type in (list, tuple):
+ seq = seq_type([1, 2, 3])
+ for numpy_type in np.typecodes["AllInteger"]:
+ i = np.dtype(numpy_type).type(2)
+ assert_equal(seq * i, seq * int(i))
+ assert_equal(i * seq, int(i) * seq)
+
+ for numpy_type in np.typecodes["All"].replace("V", ""):
+ if numpy_type in np.typecodes["AllInteger"]:
+ continue
+ i = np.dtype(numpy_type).type()
+ assert_raises(TypeError, operator.mul, seq, i)
+ assert_raises(TypeError, operator.mul, i, seq)
+
+ def test_no_seq_repeat_basic_array_like(self):
+ # Test that an array-like which does not know how to be multiplied
+ # does not attempt sequence repeat (raise TypeError).
+ # See also gh-7428.
+ class ArrayLike(object):
+ def __init__(self, arr):
+ self.arr = arr
+ def __array__(self):
+ return self.arr
+
+ # Test for simple ArrayLike above and memoryviews (original report)
+ for arr_like in (ArrayLike(np.ones(3)), memoryview(np.ones(3))):
+ assert_array_equal(arr_like * np.float32(3.), np.full(3, 3.))
+ assert_array_equal(np.float32(3.) * arr_like, np.full(3, 3.))
+ assert_array_equal(arr_like * np.int_(3), np.full(3, 3))
+ assert_array_equal(np.int_(3) * arr_like, np.full(3, 3))
+
+
class TestAbs(TestCase):
def _test_abs_func(self, absfunc):