diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2016-03-24 20:30:54 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2016-03-24 20:30:54 -0600 |
commit | f2dd0a1a8269c752172165382951ab1b91c6baf7 (patch) | |
tree | dd1197c204c6cfc5e4a3101ecb925c7f8759b6ec | |
parent | 1380fdd8824f7b404a1a5b22aa4dd6bfbfcfbc1a (diff) | |
parent | 582600415d966dbfb7abafd3151415f59c6c422d (diff) | |
download | numpy-f2dd0a1a8269c752172165382951ab1b91c6baf7.tar.gz |
Merge pull request #7439 from seberg/gh-7428
BUG: Do not try sequence repeat unless necessary
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 28 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarmath.py | 40 |
2 files changed, 58 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 3d5e51ba9..08505c5c7 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -264,10 +264,20 @@ gentype_@name@(PyObject *m1, PyObject *m2) static PyObject * gentype_multiply(PyObject *m1, PyObject *m2) { - PyObject *ret = NULL; npy_intp repeat; + /* + * If the other object supports sequence repeat and not number multiply + * we should call sequence repeat to support e.g. list repeat by numpy + * scalars (they may be converted to ndarray otherwise). + * A python defined class will always only have the nb_multiply slot and + * some classes may have neither defined. For the latter we want need + * to give the normal case a chance to convert the object to ndarray. + * Probably no class has both defined, but if they do, prefer number. + */ if (!PyArray_IsScalar(m1, Generic) && + ((Py_TYPE(m1)->tp_as_sequence != NULL) && + (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) && ((Py_TYPE(m1)->tp_as_number == NULL) || (Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) { /* Try to convert m2 to an int and try sequence repeat */ @@ -276,9 +286,11 @@ gentype_multiply(PyObject *m1, PyObject *m2) return NULL; } /* Note that npy_intp is compatible to Py_Ssize_t */ - ret = PySequence_Repeat(m1, repeat); + return PySequence_Repeat(m1, repeat); } - else if (!PyArray_IsScalar(m2, Generic) && + if (!PyArray_IsScalar(m2, Generic) && + ((Py_TYPE(m2)->tp_as_sequence != NULL) && + (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) && ((Py_TYPE(m2)->tp_as_number == NULL) || (Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) { /* Try to convert m1 to an int and try sequence repeat */ @@ -286,13 +298,11 @@ gentype_multiply(PyObject *m1, PyObject *m2) if (repeat == -1 && PyErr_Occurred()) { return NULL; } - ret = PySequence_Repeat(m2, repeat); - } - if (ret == NULL) { - PyErr_Clear(); /* no effect if not set */ - ret = PyArray_Type.tp_as_number->nb_multiply(m1, m2); + return PySequence_Repeat(m2, repeat); } - return ret; + + /* All normal cases are handled by PyArray's multiply */ + return PyArray_Type.tp_as_number->nb_multiply(m1, m2); } /**begin repeat diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index 75e91d247..b8f4388b1 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -9,7 +9,7 @@ import numpy as np from numpy.testing.utils import _gen_alignment_data from numpy.testing import ( TestCase, run_module_suite, assert_, assert_equal, assert_raises, - assert_almost_equal, assert_allclose + assert_almost_equal, assert_allclose, assert_array_equal ) types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, @@ -466,6 +466,44 @@ class TestSizeOf(TestCase): assert_raises(TypeError, d.__sizeof__, "a") +class TestMultiply(TestCase): + def test_seq_repeat(self): + # Test that basic sequences get repeated when multiplied with + # numpy integers. And errors are raised when multiplied with others. + # Some of this behaviour may be controversial and could be open for + # change. + for seq_type in (list, tuple): + seq = seq_type([1, 2, 3]) + for numpy_type in np.typecodes["AllInteger"]: + i = np.dtype(numpy_type).type(2) + assert_equal(seq * i, seq * int(i)) + assert_equal(i * seq, int(i) * seq) + + for numpy_type in np.typecodes["All"].replace("V", ""): + if numpy_type in np.typecodes["AllInteger"]: + continue + i = np.dtype(numpy_type).type() + assert_raises(TypeError, operator.mul, seq, i) + assert_raises(TypeError, operator.mul, i, seq) + + def test_no_seq_repeat_basic_array_like(self): + # Test that an array-like which does not know how to be multiplied + # does not attempt sequence repeat (raise TypeError). + # See also gh-7428. + class ArrayLike(object): + def __init__(self, arr): + self.arr = arr + def __array__(self): + return self.arr + + # Test for simple ArrayLike above and memoryviews (original report) + for arr_like in (ArrayLike(np.ones(3)), memoryview(np.ones(3))): + assert_array_equal(arr_like * np.float32(3.), np.full(3, 3.)) + assert_array_equal(np.float32(3.) * arr_like, np.full(3, 3.)) + assert_array_equal(arr_like * np.int_(3), np.full(3, 3)) + assert_array_equal(np.int_(3) * arr_like, np.full(3, 3)) + + class TestAbs(TestCase): def _test_abs_func(self, absfunc): |