diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/src/_compiled_base.c | 29 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 19 | ||||
-rw-r--r-- | numpy/testing/utils.py | 82 |
3 files changed, 113 insertions, 17 deletions
diff --git a/numpy/lib/src/_compiled_base.c b/numpy/lib/src/_compiled_base.c index f70cd2bab..a461613e3 100644 --- a/numpy/lib/src/_compiled_base.c +++ b/numpy/lib/src/_compiled_base.c @@ -161,14 +161,22 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds) len = PyArray_SIZE(lst); type = PyArray_DescrFromType(NPY_INTP); - /* handle empty list */ - if (len < 1) { - if (mlength == Py_None) { - minlength = 0; - } - else if (!(minlength = PyArray_PyIntAsIntp(mlength))) { + if (mlength == Py_None) { + minlength = 0; + } + else { + minlength = PyArray_PyIntAsIntp(mlength); + if (minlength <= 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "minlength must be positive"); + } goto fail; } + } + + /* handle empty list */ + if (len == 0) { if (!(ans = (PyArrayObject *)PyArray_Zeros(1, &minlength, type, 0))){ goto fail; } @@ -185,15 +193,6 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds) } ans_size = mx + 1; if (mlength != Py_None) { - if (!(minlength = PyArray_PyIntAsIntp(mlength))) { - goto fail; - } - if (minlength <= 0) { - /* superfluous, but may catch incorrect usage */ - PyErr_SetString(PyExc_ValueError, - "minlength must be positive"); - goto fail; - } if (ans_size < minlength) { ans_size = minlength; } diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 399a5a308..59db23a83 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -6,7 +6,7 @@ import numpy as np from numpy.testing import ( run_module_suite, TestCase, assert_, assert_equal, assert_array_equal, assert_almost_equal, assert_array_almost_equal, assert_raises, - assert_allclose, assert_array_max_ulp, assert_warns + assert_allclose, assert_array_max_ulp, assert_warns, assert_raises_regex ) from numpy.random import rand from numpy.lib import * @@ -1546,6 +1546,23 @@ class TestBincount(TestCase): y = np.bincount(x, minlength=5) assert_array_equal(y, np.zeros(5, dtype=int)) + def test_with_incorrect_minlength(self): + x = np.array([], dtype=int) + assert_raises_regex(TypeError, "an integer is required", + lambda: np.bincount(x, minlength="foobar")) + assert_raises_regex(ValueError, "must be positive", + lambda: np.bincount(x, minlength=-1)) + assert_raises_regex(ValueError, "must be positive", + lambda: np.bincount(x, minlength=0)) + + x = np.arange(5) + assert_raises_regex(TypeError, "an integer is required", + lambda: np.bincount(x, minlength="foobar")) + assert_raises_regex(ValueError, "minlength must be positive", + lambda: np.bincount(x, minlength=-1)) + assert_raises_regex(ValueError, "minlength must be positive", + lambda: np.bincount(x, minlength=0)) + class TestInterp(TestCase): def test_exceptions(self): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 930900bf2..70357d835 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -23,7 +23,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', - 'assert_', 'assert_array_almost_equal_nulp', + 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException'] @@ -1030,6 +1030,7 @@ def raises(*args,**kwargs): nose = import_nose() return nose.tools.raises(*args,**kwargs) + def assert_raises(*args,**kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) @@ -1045,6 +1046,85 @@ def assert_raises(*args,**kwargs): nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) + +assert_raises_regex_impl = None + + +def assert_raises_regex(exception_class, expected_regexp, + callable_obj=None, *args, **kwargs): + """ + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Name of this function adheres to Python 3.2+ reference, but should work in + all versions down to 2.6. + + """ + nose = import_nose() + + global assert_raises_regex_impl + if assert_raises_regex_impl is None: + try: + # Python 3.2+ + assert_raises_regex_impl = nose.tools.assert_raises_regex + except AttributeError: + try: + # 2.7+ + assert_raises_regex_impl = nose.tools.assert_raises_regexp + except AttributeError: + # 2.6 + + # This class is copied from Python2.7 stdlib almost verbatim + class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __init__(self, expected, expected_regexp=None): + self.expected = expected + self.expected_regexp = expected_regexp + + def failureException(self, msg): + return AssertionError(msg) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + raise self.failureException( + "{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value # store for later retrieval + if self.expected_regexp is None: + return True + + expected_regexp = self.expected_regexp + if isinstance(expected_regexp, basestring): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException( + '"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + + def impl(cls, regex, callable_obj, *a, **kw): + mgr = _AssertRaisesContext(cls, regex) + if callable_obj is None: + return mgr + with mgr: + callable_obj(*a, **kw) + assert_raises_regex_impl = impl + + return assert_raises_regex_impl(exception_class, expected_regexp, + callable_obj, *args, **kwargs) + + def decorate_methods(cls, decorator, testmatch=None): """ Apply a decorator to all methods in a class matching a regular expression. |