summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2017-09-01 01:59:20 +0300
committermattip <matti.picus@gmail.com>2017-09-01 01:59:20 +0300
commit446a78e0358aa84e4875fbaac07c5a2b3c1635e6 (patch)
tree2c0f6b1e1027c433884b1e55cb68c90c770e6108
parent0032e535f7ebbcb4528dbbedb9c71b47914071c7 (diff)
downloadnumpy-446a78e0358aa84e4875fbaac07c5a2b3c1635e6.tar.gz
BUG: test, fix issue #9620 __radd__ in char scalars
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src14
-rw-r--r--numpy/core/tests/test_regression.py1
-rw-r--r--numpy/core/tests/test_scalarinherit.py36
3 files changed, 50 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 5618e2d19..3b2aa8a43 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -195,9 +195,21 @@ gentype_generic_method(PyObject *self, PyObject *args, PyObject *kwds,
}
}
+static PyObject *
+gentype_add(PyObject *m1, PyObject* m2)
+{
+ /* special case str.__radd__, which should not call array_add */
+ if (PyString_Check(m1) || PyUnicode_Check(m1)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_add, gentype_add);
+ return PyArray_Type.tp_as_number->nb_add(m1, m2);
+}
+
/**begin repeat
*
- * #name = add, subtract, remainder, divmod, lshift, rshift,
+ * #name = subtract, remainder, divmod, lshift, rshift,
* and, xor, or, floor_divide, true_divide#
*/
static PyObject *
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index a3c94e312..84469d03b 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -2260,5 +2260,6 @@ class TestRegression(object):
item2 = copy.copy(item)
assert_equal(item, item2)
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/core/tests/test_scalarinherit.py b/numpy/core/tests/test_scalarinherit.py
index 8e0910d92..c5cd266eb 100644
--- a/numpy/core/tests/test_scalarinherit.py
+++ b/numpy/core/tests/test_scalarinherit.py
@@ -38,5 +38,41 @@ class TestInherit(object):
y = C0(2.0)
assert_(str(y) == '2.0')
+
+class TestCharacter(object):
+ def test_char_radd(self):
+ # GH issue 9620, reached gentype_add and raise TypeError
+ np_s = np.string_('abc')
+ np_u = np.unicode_('abc')
+ s = b'def'
+ u = u'def'
+ assert_(np_s.__radd__(np_s) is NotImplemented)
+ assert_(np_s.__radd__(np_u) is NotImplemented)
+ assert_(np_s.__radd__(s) is NotImplemented)
+ assert_(np_s.__radd__(u) is NotImplemented)
+ assert_(np_u.__radd__(np_s) is NotImplemented)
+ assert_(np_u.__radd__(np_u) is NotImplemented)
+ assert_(np_u.__radd__(s) is NotImplemented)
+ assert_(np_u.__radd__(u) is NotImplemented)
+ assert_(s + np_s == b'defabc')
+ assert_(u + np_u == u'defabc')
+
+
+ class Mystr(str, np.generic):
+ # would segfault
+ pass
+
+ ret = s + Mystr('abc')
+ assert_(type(ret) is type(s))
+
+ def test_char_repeat(self):
+ np_s = np.string_('abc')
+ np_u = np.unicode_('abc')
+ np_i = np.int(5)
+ res_np = np_s * np_i
+ res_s = b'abc' * 5
+ assert_(res_np == res_s)
+
+
if __name__ == "__main__":
run_module_suite()