summary | refs | log | tree | commit | diff
path: root/Lib
diff options
context:
space:
mode:
author: Barry Warsaw <barry@python.org>, 2012-02-20 20:42:21 -0500
committer: Barry Warsaw <barry@python.org>, 2012-02-20 20:42:21 -0500
commit: 1e13eb084f72d5993cbb726e45b36bdb69c83a24 (patch)
tree: 1db691c15c5980a870bcc2606a6d2afc77e28bad /Lib
parent: f5a5beb33985b4b55480de267084b90d89a5c5c4 (diff)
download: cpython-git-1e13eb084f72d5993cbb726e45b36bdb69c83a24.tar.gz
- Issue #13703: oCERT-2011-003: add -R command-line option and PYTHONHASHSEED
environment variable, to provide an opt-in way to protect against denial of service attacks due to hash collisions within the dict and set types. Patch by David Malcolm, based on work by Victor Stinner.
Diffstat (limited to 'Lib')
-rw-r--r-- Lib/os.py 19
-rw-r--r-- Lib/test/test_cmd_line.py 14
-rw-r--r-- Lib/test/test_hash.py 100
-rw-r--r-- Lib/test/test_os.py 54
-rw-r--r-- Lib/test/test_set.py 52
-rw-r--r-- Lib/test/test_support.py 12
-rw-r--r-- Lib/test/test_symtable.py 7
-rw-r--r-- Lib/test/test_sys.py 2
8 files changed, 222 insertions, 38 deletions
diff --git a/Lib/os.py b/Lib/os.py
index 88adc1555a..a3fd71185f 100644
--- a/Lib/os.py
+++ b/Lib/os.py
@@ -742,22 +742,3 @@ try:
_make_statvfs_result)
except NameError: # statvfs_result may not exist
pass
-
-if not _exists("urandom"):
- def urandom(n):
- """urandom(n) -> str
-
- Return a string of n random bytes suitable for cryptographic use.
-
- """
- try:
- _urandomfd = open("/dev/urandom", O_RDONLY)
- except (OSError, IOError):
- raise NotImplementedError("/dev/urandom (or equivalent) not found")
- try:
- bs = b""
- while n - len(bs) >= 1:
- bs += read(_urandomfd, n - len(bs))
- finally:
- close(_urandomfd)
- return bs
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index efef74f09c..28362dffbf 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -103,6 +103,20 @@ class CmdLineTest(unittest.TestCase):
self.exit_code('-c', 'pass'),
0)
+ def test_hash_randomization(self):
+ # Verify that -R enables hash randomization:
+ self.verify_valid_flag('-R')
+ hashes = []
+ for i in range(2):
+ code = 'print(hash("spam"))'
+ data = self.start_python('-R', '-c', code)
+ hashes.append(data)
+ self.assertNotEqual(hashes[0], hashes[1])
+
+ # Verify that sys.flags contains hash_randomization
+ code = 'import sys; print sys.flags'
+ data = self.start_python('-R', '-c', code)
+ self.assertTrue('hash_randomization=1' in data)
def test_main():
test.test_support.run_unittest(CmdLineTest)
diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py
index 7ce40b95b7..1a982c473d 100644
--- a/Lib/test/test_hash.py
+++ b/Lib/test/test_hash.py
@@ -3,10 +3,18 @@
#
# Also test that hash implementations are inherited as expected
+import os
+import sys
+import struct
+import datetime
import unittest
+import subprocess
+
from test import test_support
from collections import Hashable
+IS_64BIT = (struct.calcsize('l') == 8)
+
class HashEqualityTestCase(unittest.TestCase):
@@ -133,10 +141,100 @@ class HashBuiltinsTestCase(unittest.TestCase):
for obj in self.hashes_to_check:
self.assertEqual(hash(obj), _default_hash(obj))
+class HashRandomizationTests(unittest.TestCase):
+
+ # Each subclass should define a field "repr_", containing the repr() of
+ # an object to be tested
+
+ def get_hash_command(self, repr_):
+ return 'print(hash(%s))' % repr_
+
+ def get_hash(self, repr_, seed=None):
+ env = os.environ.copy()
+ if seed is not None:
+ env['PYTHONHASHSEED'] = str(seed)
+ else:
+ env.pop('PYTHONHASHSEED', None)
+ cmd_line = [sys.executable, '-c', self.get_hash_command(repr_)]
+ p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ env=env)
+ out, err = p.communicate()
+ out = test_support.strip_python_stderr(out)
+ return int(out.strip())
+
+ def test_randomized_hash(self):
+ # two runs should return different hashes
+ run1 = self.get_hash(self.repr_, seed='random')
+ run2 = self.get_hash(self.repr_, seed='random')
+ self.assertNotEqual(run1, run2)
+
+class StringlikeHashRandomizationTests(HashRandomizationTests):
+ def test_null_hash(self):
+ # PYTHONHASHSEED=0 disables the randomized hash
+ if IS_64BIT:
+ known_hash_of_obj = 1453079729188098211
+ else:
+ known_hash_of_obj = -1600925533
+
+ # Randomization is disabled by default:
+ self.assertEqual(self.get_hash(self.repr_), known_hash_of_obj)
+
+ # It can also be disabled by setting the seed to 0:
+ self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj)
+
+ def test_fixed_hash(self):
+ # test a fixed seed for the randomized hash
+ # Note that all types share the same values:
+ if IS_64BIT:
+ h = -4410911502303878509
+ else:
+ h = -206076799
+ self.assertEqual(self.get_hash(self.repr_, seed=42), h)
+
+class StrHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = repr('abc')
+
+ def test_empty_string(self):
+ self.assertEqual(hash(""), 0)
+
+class UnicodeHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = repr(u'abc')
+
+ def test_empty_string(self):
+ self.assertEqual(hash(u""), 0)
+
+class BufferHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = 'buffer("abc")'
+
+ def test_empty_string(self):
+ self.assertEqual(hash(buffer("")), 0)
+
+class DatetimeTests(HashRandomizationTests):
+ def get_hash_command(self, repr_):
+ return 'import datetime; print(hash(%s))' % repr_
+
+class DatetimeDateTests(DatetimeTests):
+ repr_ = repr(datetime.date(1066, 10, 14))
+
+class DatetimeDatetimeTests(DatetimeTests):
+ repr_ = repr(datetime.datetime(1, 2, 3, 4, 5, 6, 7))
+
+class DatetimeTimeTests(DatetimeTests):
+ repr_ = repr(datetime.time(0))
+
+
def test_main():
test_support.run_unittest(HashEqualityTestCase,
HashInheritanceTestCase,
- HashBuiltinsTestCase)
+ HashBuiltinsTestCase,
+ StrHashRandomizationTests,
+ UnicodeHashRandomizationTests,
+ BufferHashRandomizationTests,
+ DatetimeDateTests,
+ DatetimeDatetimeTests,
+ DatetimeTimeTests)
+
if __name__ == "__main__":
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index db7e9b4105..0561499ed7 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -6,6 +6,8 @@ import os
import unittest
import warnings
import sys
+import subprocess
+
from test import test_support
warnings.filterwarnings("ignore", "tempnam", RuntimeWarning, __name__)
@@ -499,18 +501,46 @@ class DevNullTests (unittest.TestCase):
class URandomTests (unittest.TestCase):
def test_urandom(self):
- try:
- with test_support.check_warnings():
- self.assertEqual(len(os.urandom(1)), 1)
- self.assertEqual(len(os.urandom(10)), 10)
- self.assertEqual(len(os.urandom(100)), 100)
- self.assertEqual(len(os.urandom(1000)), 1000)
- # see http://bugs.python.org/issue3708
- self.assertEqual(len(os.urandom(0.9)), 0)
- self.assertEqual(len(os.urandom(1.1)), 1)
- self.assertEqual(len(os.urandom(2.0)), 2)
- except NotImplementedError:
- pass
+ with test_support.check_warnings():
+ self.assertEqual(len(os.urandom(1)), 1)
+ self.assertEqual(len(os.urandom(10)), 10)
+ self.assertEqual(len(os.urandom(100)), 100)
+ self.assertEqual(len(os.urandom(1000)), 1000)
+ # see http://bugs.python.org/issue3708
+ self.assertEqual(len(os.urandom(0.9)), 0)
+ self.assertEqual(len(os.urandom(1.1)), 1)
+ self.assertEqual(len(os.urandom(2.0)), 2)
+
+ def test_urandom_length(self):
+ self.assertEqual(len(os.urandom(0)), 0)
+ self.assertEqual(len(os.urandom(1)), 1)
+ self.assertEqual(len(os.urandom(10)), 10)
+ self.assertEqual(len(os.urandom(100)), 100)
+ self.assertEqual(len(os.urandom(1000)), 1000)
+
+ def test_urandom_value(self):
+ data1 = os.urandom(16)
+ data2 = os.urandom(16)
+ self.assertNotEqual(data1, data2)
+
+ def get_urandom_subprocess(self, count):
+ code = '\n'.join((
+ 'import os, sys',
+ 'data = os.urandom(%s)' % count,
+ 'sys.stdout.write(data)',
+ 'sys.stdout.flush()'))
+ cmd_line = [sys.executable, '-c', code]
+ p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ out, err = p.communicate()
+ out = test_support.strip_python_stderr(out)
+ self.assertEqual(len(out), count)
+ return out
+
+ def test_urandom_subprocess(self):
+ data1 = self.get_urandom_subprocess(16)
+ data2 = self.get_urandom_subprocess(16)
+ self.assertNotEqual(data1, data2)
class Win32ErrorTests(unittest.TestCase):
def test_rename(self):
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 3539a14065..18822cad59 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -6,7 +6,6 @@ import weakref
import operator
import copy
import pickle
-import os
from random import randrange, shuffle
import sys
import collections
@@ -688,6 +687,17 @@ class TestBasicOps(unittest.TestCase):
if self.repr is not None:
self.assertEqual(repr(self.set), self.repr)
+ def check_repr_against_values(self):
+ text = repr(self.set)
+ self.assertTrue(text.startswith('{'))
+ self.assertTrue(text.endswith('}'))
+
+ result = text[1:-1].split(', ')
+ result.sort()
+ sorted_repr_values = [repr(value) for value in self.values]
+ sorted_repr_values.sort()
+ self.assertEqual(result, sorted_repr_values)
+
def test_print(self):
fo = open(test_support.TESTFN, "wb")
try:
@@ -837,6 +847,46 @@ class TestBasicOpsTriple(TestBasicOps):
self.length = 3
self.repr = None
+#------------------------------------------------------------------------------
+
+class TestBasicOpsString(TestBasicOps):
+ def setUp(self):
+ self.case = "string set"
+ self.values = ["a", "b", "c"]
+ self.set = set(self.values)
+ self.dup = set(self.values)
+ self.length = 3
+
+ def test_repr(self):
+ self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+class TestBasicOpsUnicode(TestBasicOps):
+ def setUp(self):
+ self.case = "unicode set"
+ self.values = [u"a", u"b", u"c"]
+ self.set = set(self.values)
+ self.dup = set(self.values)
+ self.length = 3
+
+ def test_repr(self):
+ self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+class TestBasicOpsMixedStringUnicode(TestBasicOps):
+ def setUp(self):
+ self.case = "string and bytes set"
+ self.values = ["a", "b", u"a", u"b"]
+ self.set = set(self.values)
+ self.dup = set(self.values)
+ self.length = 4
+
+ def test_repr(self):
+ with test_support.check_warnings():
+ self.check_repr_against_values()
+
#==============================================================================
def baditer():
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 2212fce991..b572f9abe5 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -24,7 +24,7 @@ __all__ = ["Error", "TestFailed", "TestSkipped", "ResourceDenied", "import_modul
"captured_stdout", "TransientResource", "transient_internet",
"run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
"BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
- "threading_cleanup", "reap_children"]
+ "threading_cleanup", "reap_children", "strip_python_stderr"]
class Error(Exception):
"""Base class for regression test exceptions."""
@@ -893,3 +893,13 @@ def reap_children():
break
except:
break
+
+def strip_python_stderr(stderr):
+ """Strip the stderr of a Python process from potential debug output
+ emitted by the interpreter.
+
+ This will typically be run on the result of the communicate() method
+ of a subprocess.Popen object.
+ """
+ stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip()
+ return stderr
diff --git a/Lib/test/test_symtable.py b/Lib/test/test_symtable.py
index 0b4190d2d5..4b54f55afd 100644
--- a/Lib/test/test_symtable.py
+++ b/Lib/test/test_symtable.py
@@ -105,10 +105,11 @@ class SymtableTest(unittest.TestCase):
def test_function_info(self):
func = self.spam
- self.assertEqual(func.get_parameters(), ("a", "b", "kw", "var"))
- self.assertEqual(func.get_locals(),
+ self.assertEqual(
+ tuple(sorted(func.get_parameters())), ("a", "b", "kw", "var"))
+ self.assertEqual(tuple(sorted(func.get_locals())),
("a", "b", "bar", "internal", "kw", "var", "x"))
- self.assertEqual(func.get_globals(), ("bar", "glob"))
+ self.assertEqual(tuple(sorted(func.get_globals())), ("bar", "glob"))
self.assertEqual(self.internal.get_frees(), ("x",))
def test_globals(self):
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index fd6fb2bda1..e82569ac02 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -384,7 +384,7 @@ class SysModuleTest(unittest.TestCase):
attrs = ("debug", "py3k_warning", "division_warning", "division_new",
"inspect", "interactive", "optimize", "dont_write_bytecode",
"no_site", "ignore_environment", "tabcheck", "verbose",
- "unicode", "bytes_warning")
+ "unicode", "bytes_warning", "hash_randomization")
for attr in attrs:
self.assert_(hasattr(sys.flags, attr), attr)
self.assertEqual(type(getattr(sys.flags, attr)), int, attr)