summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2007-09-10 22:02:25 +0000
committerGuido van Rossum <guido@python.org>2007-09-10 22:02:25 +0000
commit1ff91d95a280449cfd9c723a081cb7b19a52e758 (patch)
treec537822cc870185f7042767b7ed5ca40b7d5da50
parent98d19dafd9c9d95338887b9e53c77ec6960918e0 (diff)
downloadcpython-git-1ff91d95a280449cfd9c723a081cb7b19a52e758.tar.gz
Patch # 1140 (my code, approved by Effbot).
Make sure the type of the return value of re.sub(x, y, z) is the type of y+x (i.e. unicode if either is unicode, str if they are both str) even if there are no substitutions or if x==z (which triggered various special cases in join_list()). Could be backported to 2.5; no need to port to 3.0.
-rw-r--r--Lib/test/test_re.py25
-rw-r--r--Modules/_sre.c25
2 files changed, 33 insertions, 17 deletions
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index cfb949c4e3..aa403bac1a 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -83,6 +83,31 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.sub('\r\n', '\n', 'abc\r\ndef\r\n'),
'abc\ndef\n')
+ def test_bug_1140(self):
+ # re.sub(x, y, u'') should return u'', not '', and
+ # re.sub(x, y, '') should return '', not u''.
+ # Also:
+ # re.sub(x, y, unicode(x)) should return unicode(y), and
+ # re.sub(x, y, str(x)) should return
+ # str(y) if isinstance(y, str) else unicode(y).
+ for x in 'x', u'x':
+ for y in 'y', u'y':
+ z = re.sub(x, y, u'')
+ self.assertEqual(z, u'')
+ self.assertEqual(type(z), unicode)
+ #
+ z = re.sub(x, y, '')
+ self.assertEqual(z, '')
+ self.assertEqual(type(z), str)
+ #
+ z = re.sub(x, y, unicode(x))
+ self.assertEqual(z, y)
+ self.assertEqual(type(z), unicode)
+ #
+ z = re.sub(x, y, str(x))
+ self.assertEqual(z, y)
+ self.assertEqual(type(z), type(y))
+
def test_sub_template_numeric_escape(self):
# bug 776311 and friends
self.assertEqual(re.sub('x', r'\0', 'x'), '\0')
diff --git a/Modules/_sre.c b/Modules/_sre.c
index 7dafaeb47a..51a73483d1 100644
--- a/Modules/_sre.c
+++ b/Modules/_sre.c
@@ -1979,7 +1979,7 @@ deepcopy(PyObject** object, PyObject* memo)
#endif
static PyObject*
-join_list(PyObject* list, PyObject* pattern)
+join_list(PyObject* list, PyObject* string)
{
/* join list elements */
@@ -1990,24 +1990,15 @@ join_list(PyObject* list, PyObject* pattern)
#endif
PyObject* result;
- switch (PyList_GET_SIZE(list)) {
- case 0:
- Py_DECREF(list);
- return PySequence_GetSlice(pattern, 0, 0);
- case 1:
- result = PyList_GET_ITEM(list, 0);
- Py_INCREF(result);
- Py_DECREF(list);
- return result;
- }
-
- /* two or more elements: slice out a suitable separator from the
- first member, and use that to join the entire list */
-
- joiner = PySequence_GetSlice(pattern, 0, 0);
+ joiner = PySequence_GetSlice(string, 0, 0);
if (!joiner)
return NULL;
+ if (PyList_GET_SIZE(list) == 0) {
+ Py_DECREF(list);
+ return joiner;
+ }
+
#if PY_VERSION_HEX >= 0x01060000
function = PyObject_GetAttrString(joiner, "join");
if (!function) {
@@ -2443,7 +2434,7 @@ next:
Py_DECREF(filter);
/* convert list to single string (also removes list) */
- item = join_list(list, self->pattern);
+ item = join_list(list, string);
if (!item)
return NULL;