diff options
author | mattip <matti.picus@gmail.com> | 2017-09-01 01:59:20 +0300 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2017-09-01 01:59:20 +0300 |
commit | 446a78e0358aa84e4875fbaac07c5a2b3c1635e6 (patch) | |
tree | 2c0f6b1e1027c433884b1e55cb68c90c770e6108 | |
parent | 0032e535f7ebbcb4528dbbedb9c71b47914071c7 (diff) | |
download | numpy-446a78e0358aa84e4875fbaac07c5a2b3c1635e6.tar.gz |
BUG: test, fix issue #9620 __radd__ in char scalars
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 14 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarinherit.py | 36 |
3 files changed, 50 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 5618e2d19..3b2aa8a43 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -195,9 +195,21 @@ gentype_generic_method(PyObject *self, PyObject *args, PyObject *kwds, } } +static PyObject * +gentype_add(PyObject *m1, PyObject* m2) +{ + /* special case str.__radd__, which should not call array_add */ + if (PyString_Check(m1) || PyUnicode_Check(m1)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_add, gentype_add); + return PyArray_Type.tp_as_number->nb_add(m1, m2); +} + /**begin repeat * - * #name = add, subtract, remainder, divmod, lshift, rshift, + * #name = subtract, remainder, divmod, lshift, rshift, * and, xor, or, floor_divide, true_divide# */ static PyObject * diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index a3c94e312..84469d03b 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -2260,5 +2260,6 @@ class TestRegression(object): item2 = copy.copy(item) assert_equal(item, item2) + if __name__ == "__main__": run_module_suite() diff --git a/numpy/core/tests/test_scalarinherit.py b/numpy/core/tests/test_scalarinherit.py index 8e0910d92..c5cd266eb 100644 --- a/numpy/core/tests/test_scalarinherit.py +++ b/numpy/core/tests/test_scalarinherit.py @@ -38,5 +38,41 @@ class TestInherit(object): y = C0(2.0) assert_(str(y) == '2.0') + +class TestCharacter(object): + def test_char_radd(self): + # GH issue 9620, reached gentype_add and raise TypeError + np_s = np.string_('abc') + np_u = np.unicode_('abc') + s = b'def' + u = u'def' + assert_(np_s.__radd__(np_s) is NotImplemented) + assert_(np_s.__radd__(np_u) is NotImplemented) + assert_(np_s.__radd__(s) is NotImplemented) + assert_(np_s.__radd__(u) is NotImplemented) + assert_(np_u.__radd__(np_s) is NotImplemented) + assert_(np_u.__radd__(np_u) is NotImplemented) + assert_(np_u.__radd__(s) is NotImplemented) + assert_(np_u.__radd__(u) is NotImplemented) + assert_(s + np_s == b'defabc') + assert_(u + np_u == u'defabc') + + + class Mystr(str, np.generic): + # would segfault + pass + + ret = s + Mystr('abc') + assert_(type(ret) is type(s)) + + def test_char_repeat(self): + np_s = np.string_('abc') + np_u = np.unicode_('abc') + np_i = np.int(5) + res_np = np_s * np_i + res_s = b'abc' * 5 + assert_(res_np == res_s) + + if __name__ == "__main__": run_module_suite() |