summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/src/_compiled_base.c29
-rw-r--r--numpy/lib/tests/test_function_base.py19
-rw-r--r--numpy/testing/utils.py82
3 files changed, 113 insertions, 17 deletions
diff --git a/numpy/lib/src/_compiled_base.c b/numpy/lib/src/_compiled_base.c
index f70cd2bab..a461613e3 100644
--- a/numpy/lib/src/_compiled_base.c
+++ b/numpy/lib/src/_compiled_base.c
@@ -161,14 +161,22 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
len = PyArray_SIZE(lst);
type = PyArray_DescrFromType(NPY_INTP);
- /* handle empty list */
- if (len < 1) {
- if (mlength == Py_None) {
- minlength = 0;
- }
- else if (!(minlength = PyArray_PyIntAsIntp(mlength))) {
+ if (mlength == Py_None) {
+ minlength = 0;
+ }
+ else {
+ minlength = PyArray_PyIntAsIntp(mlength);
+ if (minlength <= 0) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ValueError,
+ "minlength must be positive");
+ }
goto fail;
}
+ }
+
+ /* handle empty list */
+ if (len == 0) {
if (!(ans = (PyArrayObject *)PyArray_Zeros(1, &minlength, type, 0))){
goto fail;
}
@@ -185,15 +193,6 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
}
ans_size = mx + 1;
if (mlength != Py_None) {
- if (!(minlength = PyArray_PyIntAsIntp(mlength))) {
- goto fail;
- }
- if (minlength <= 0) {
- /* superfluous, but may catch incorrect usage */
- PyErr_SetString(PyExc_ValueError,
- "minlength must be positive");
- goto fail;
- }
if (ans_size < minlength) {
ans_size = minlength;
}
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 399a5a308..59db23a83 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -6,7 +6,7 @@ import numpy as np
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
assert_almost_equal, assert_array_almost_equal, assert_raises,
- assert_allclose, assert_array_max_ulp, assert_warns
+ assert_allclose, assert_array_max_ulp, assert_warns, assert_raises_regex
)
from numpy.random import rand
from numpy.lib import *
@@ -1546,6 +1546,23 @@ class TestBincount(TestCase):
y = np.bincount(x, minlength=5)
assert_array_equal(y, np.zeros(5, dtype=int))
+ def test_with_incorrect_minlength(self):
+ x = np.array([], dtype=int)
+ assert_raises_regex(TypeError, "an integer is required",
+ lambda: np.bincount(x, minlength="foobar"))
+ assert_raises_regex(ValueError, "must be positive",
+ lambda: np.bincount(x, minlength=-1))
+ assert_raises_regex(ValueError, "must be positive",
+ lambda: np.bincount(x, minlength=0))
+
+ x = np.arange(5)
+ assert_raises_regex(TypeError, "an integer is required",
+ lambda: np.bincount(x, minlength="foobar"))
+ assert_raises_regex(ValueError, "minlength must be positive",
+ lambda: np.bincount(x, minlength=-1))
+ assert_raises_regex(ValueError, "minlength must be positive",
+ lambda: np.bincount(x, minlength=0))
+
class TestInterp(TestCase):
def test_exceptions(self):
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 930900bf2..70357d835 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -23,7 +23,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
- 'assert_', 'assert_array_almost_equal_nulp',
+ 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
'assert_allclose', 'IgnoreException']
@@ -1030,6 +1030,7 @@ def raises(*args,**kwargs):
nose = import_nose()
return nose.tools.raises(*args,**kwargs)
+
def assert_raises(*args,**kwargs):
"""
assert_raises(exception_class, callable, *args, **kwargs)
@@ -1045,6 +1046,85 @@ def assert_raises(*args,**kwargs):
nose = import_nose()
return nose.tools.assert_raises(*args,**kwargs)
+
+assert_raises_regex_impl = None
+
+
+def assert_raises_regex(exception_class, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """
+ Fail unless an exception of class exception_class and with message that
+ matches expected_regexp is thrown by callable when invoked with arguments
+ args and keyword arguments kwargs.
+
+ Name of this function adheres to Python 3.2+ reference, but should work in
+ all versions down to 2.6.
+
+ """
+ nose = import_nose()
+
+ global assert_raises_regex_impl
+ if assert_raises_regex_impl is None:
+ try:
+ # Python 3.2+
+ assert_raises_regex_impl = nose.tools.assert_raises_regex
+ except AttributeError:
+ try:
+ # 2.7+
+ assert_raises_regex_impl = nose.tools.assert_raises_regexp
+ except AttributeError:
+ # 2.6
+
+ # This class is copied from Python2.7 stdlib almost verbatim
+ class _AssertRaisesContext(object):
+ """A context manager used to implement TestCase.assertRaises* methods."""
+
+ def __init__(self, expected, expected_regexp=None):
+ self.expected = expected
+ self.expected_regexp = expected_regexp
+
+ def failureException(self, msg):
+ return AssertionError(msg)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is None:
+ try:
+ exc_name = self.expected.__name__
+ except AttributeError:
+ exc_name = str(self.expected)
+ raise self.failureException(
+ "{0} not raised".format(exc_name))
+ if not issubclass(exc_type, self.expected):
+ # let unexpected exceptions pass through
+ return False
+ self.exception = exc_value # store for later retrieval
+ if self.expected_regexp is None:
+ return True
+
+ expected_regexp = self.expected_regexp
+ if isinstance(expected_regexp, basestring):
+ expected_regexp = re.compile(expected_regexp)
+ if not expected_regexp.search(str(exc_value)):
+ raise self.failureException(
+ '"%s" does not match "%s"' %
+ (expected_regexp.pattern, str(exc_value)))
+ return True
+
+ def impl(cls, regex, callable_obj, *a, **kw):
+ mgr = _AssertRaisesContext(cls, regex)
+ if callable_obj is None:
+ return mgr
+ with mgr:
+ callable_obj(*a, **kw)
+ assert_raises_regex_impl = impl
+
+ return assert_raises_regex_impl(exception_class, expected_regexp,
+ callable_obj, *args, **kwargs)
+
+
def decorate_methods(cls, decorator, testmatch=None):
"""
Apply a decorator to all methods in a class matching a regular expression.