summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-03-24 20:30:54 -0600
committerCharles Harris <charlesr.harris@gmail.com>2016-03-24 20:30:54 -0600
commitf2dd0a1a8269c752172165382951ab1b91c6baf7 (patch)
treedd1197c204c6cfc5e4a3101ecb925c7f8759b6ec
parent1380fdd8824f7b404a1a5b22aa4dd6bfbfcfbc1a (diff)
parent582600415d966dbfb7abafd3151415f59c6c422d (diff)
downloadnumpy-f2dd0a1a8269c752172165382951ab1b91c6baf7.tar.gz
Merge pull request #7439 from seberg/gh-7428
BUG: Do not try sequence repeat unless necessary
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src28
-rw-r--r--numpy/core/tests/test_scalarmath.py40
2 files changed, 58 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 3d5e51ba9..08505c5c7 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -264,10 +264,20 @@ gentype_@name@(PyObject *m1, PyObject *m2)
static PyObject *
gentype_multiply(PyObject *m1, PyObject *m2)
{
- PyObject *ret = NULL;
npy_intp repeat;
+ /*
+ * If the other object supports sequence repeat and not number multiply
+ * we should call sequence repeat to support e.g. list repeat by numpy
+ * scalars (they may be converted to ndarray otherwise).
+ * A python defined class will always only have the nb_multiply slot and
+ * some classes may have neither defined. For the latter we want need
+ * to give the normal case a chance to convert the object to ndarray.
+ * Probably no class has both defined, but if they do, prefer number.
+ */
if (!PyArray_IsScalar(m1, Generic) &&
+ ((Py_TYPE(m1)->tp_as_sequence != NULL) &&
+ (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) &&
((Py_TYPE(m1)->tp_as_number == NULL) ||
(Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) {
/* Try to convert m2 to an int and try sequence repeat */
@@ -276,9 +286,11 @@ gentype_multiply(PyObject *m1, PyObject *m2)
return NULL;
}
/* Note that npy_intp is compatible to Py_Ssize_t */
- ret = PySequence_Repeat(m1, repeat);
+ return PySequence_Repeat(m1, repeat);
}
- else if (!PyArray_IsScalar(m2, Generic) &&
+ if (!PyArray_IsScalar(m2, Generic) &&
+ ((Py_TYPE(m2)->tp_as_sequence != NULL) &&
+ (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) &&
((Py_TYPE(m2)->tp_as_number == NULL) ||
(Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) {
/* Try to convert m1 to an int and try sequence repeat */
@@ -286,13 +298,11 @@ gentype_multiply(PyObject *m1, PyObject *m2)
if (repeat == -1 && PyErr_Occurred()) {
return NULL;
}
- ret = PySequence_Repeat(m2, repeat);
- }
- if (ret == NULL) {
- PyErr_Clear(); /* no effect if not set */
- ret = PyArray_Type.tp_as_number->nb_multiply(m1, m2);
+ return PySequence_Repeat(m2, repeat);
}
- return ret;
+
+ /* All normal cases are handled by PyArray's multiply */
+ return PyArray_Type.tp_as_number->nb_multiply(m1, m2);
}
/**begin repeat
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index 75e91d247..b8f4388b1 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -9,7 +9,7 @@ import numpy as np
from numpy.testing.utils import _gen_alignment_data
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal, assert_raises,
- assert_almost_equal, assert_allclose
+ assert_almost_equal, assert_allclose, assert_array_equal
)
types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc,
@@ -466,6 +466,44 @@ class TestSizeOf(TestCase):
assert_raises(TypeError, d.__sizeof__, "a")
+class TestMultiply(TestCase):
+ def test_seq_repeat(self):
+ # Test that basic sequences get repeated when multiplied with
+ # numpy integers. And errors are raised when multiplied with others.
+ # Some of this behaviour may be controversial and could be open for
+ # change.
+ for seq_type in (list, tuple):
+ seq = seq_type([1, 2, 3])
+ for numpy_type in np.typecodes["AllInteger"]:
+ i = np.dtype(numpy_type).type(2)
+ assert_equal(seq * i, seq * int(i))
+ assert_equal(i * seq, int(i) * seq)
+
+ for numpy_type in np.typecodes["All"].replace("V", ""):
+ if numpy_type in np.typecodes["AllInteger"]:
+ continue
+ i = np.dtype(numpy_type).type()
+ assert_raises(TypeError, operator.mul, seq, i)
+ assert_raises(TypeError, operator.mul, i, seq)
+
+ def test_no_seq_repeat_basic_array_like(self):
+ # Test that an array-like which does not know how to be multiplied
+ # does not attempt sequence repeat (raise TypeError).
+ # See also gh-7428.
+ class ArrayLike(object):
+ def __init__(self, arr):
+ self.arr = arr
+ def __array__(self):
+ return self.arr
+
+ # Test for simple ArrayLike above and memoryviews (original report)
+ for arr_like in (ArrayLike(np.ones(3)), memoryview(np.ones(3))):
+ assert_array_equal(arr_like * np.float32(3.), np.full(3, 3.))
+ assert_array_equal(np.float32(3.) * arr_like, np.full(3, 3.))
+ assert_array_equal(arr_like * np.int_(3), np.full(3, 3))
+ assert_array_equal(np.int_(3) * arr_like, np.full(3, 3))
+
+
class TestAbs(TestCase):
def _test_abs_func(self, absfunc):