diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/_generator.pyx | 22 | ||||
-rw-r--r-- | numpy/random/mtrand.pyx | 20 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 15 | ||||
-rw-r--r-- | numpy/random/tests/test_randomstate.py | 6 | ||||
-rw-r--r-- | numpy/typing/tests/data/mypy.ini | 1 | ||||
-rw-r--r-- | numpy/typing/tests/test_typing.py | 89 |
6 files changed, 100 insertions, 53 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index 7d652ce89..3033a1495 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -2,6 +2,7 @@ #cython: wraparound=False, nonecheck=False, boundscheck=False, cdivision=True, language_level=3 import operator import warnings +from collections.abc import MutableSequence from cpython.pycapsule cimport PyCapsule_IsValid, PyCapsule_GetPointer from cpython cimport (Py_INCREF, PyFloat_AsDouble) @@ -4347,14 +4348,14 @@ cdef class Generator: """ shuffle(x, axis=0) - Modify a sequence in-place by shuffling its contents. + Modify an array or sequence in-place by shuffling its contents. The order of sub-arrays is changed but their contents remains the same. Parameters ---------- - x : array_like - The array or list to be shuffled. + x : ndarray or MutableSequence + The array, list or mutable sequence to be shuffled. axis : int, optional The axis which `x` is shuffled along. Default is 0. It is only supported on `ndarray` objects. @@ -4414,7 +4415,11 @@ cdef class Generator: with self.lock, nogil: _shuffle_raw_wrap(&self._bitgen, n, 1, itemsize, stride, x_ptr, buf_ptr) - elif isinstance(x, np.ndarray) and x.ndim and x.size: + elif isinstance(x, np.ndarray): + if x.size == 0: + # shuffling is a no-op + return + x = np.swapaxes(x, 0, axis) buf = np.empty_like(x[0, ...]) with self.lock: @@ -4428,6 +4433,15 @@ cdef class Generator: x[i] = buf else: # Untyped path. + if not isinstance(x, MutableSequence): + # See gh-18206. We may decide to deprecate here in the future. + warnings.warn( + "`x` isn't a recognized object; `shuffle` is not guaranteed " + "to behave correctly. 
E.g., non-numpy array/tensor objects " + "with view semantics may contain duplicates after shuffling.", + UserWarning, stacklevel=2 + ) + if axis != 0: raise NotImplementedError("Axis argument is only supported " "on ndarray objects") diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index d43e7f5aa..814630c03 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -2,6 +2,7 @@ #cython: wraparound=False, nonecheck=False, boundscheck=False, cdivision=True, language_level=3 import operator import warnings +from collections.abc import MutableSequence import numpy as np @@ -4402,8 +4403,8 @@ cdef class RandomState: Parameters ---------- - x : array_like - The array or list to be shuffled. + x : ndarray or MutableSequence + The array, list or mutable sequence to be shuffled. Returns ------- @@ -4456,7 +4457,11 @@ cdef class RandomState: self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr) else: self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr) - elif isinstance(x, np.ndarray) and x.ndim and x.size: + elif isinstance(x, np.ndarray): + if x.size == 0: + # shuffling is a no-op + return + buf = np.empty_like(x[0, ...]) with self.lock: for i in reversed(range(1, n)): @@ -4468,6 +4473,15 @@ cdef class RandomState: x[i] = buf else: # Untyped path. + if not isinstance(x, MutableSequence): + # See gh-18206. We may decide to deprecate here in the future. + warnings.warn( + "`x` isn't a recognized object; `shuffle` is not guaranteed " + "to behave correctly. 
E.g., non-numpy array/tensor objects " + "with view semantics may contain duplicates after shuffling.", + UserWarning, stacklevel=2 + ) + with self.lock: for i in reversed(range(1, n)): j = random_interval(&self._bitgen, i) diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index c4fb5883c..47c81584c 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -960,6 +960,14 @@ class TestRandomDist: random.shuffle(actual, axis=-1) assert_array_equal(actual, desired) + def test_shuffle_custom_axis_empty(self): + random = Generator(MT19937(self.seed)) + desired = np.array([]).reshape((0, 6)) + for axis in (0, 1): + actual = np.array([]).reshape((0, 6)) + random.shuffle(actual, axis=axis) + assert_array_equal(actual, desired) + def test_shuffle_axis_nonsquare(self): y1 = np.arange(20).reshape(2, 10) y2 = y1.copy() @@ -993,6 +1001,11 @@ class TestRandomDist: arr = [[1, 2, 3], [4, 5, 6]] assert_raises(NotImplementedError, random.shuffle, arr, 1) + arr = np.array(3) + assert_raises(TypeError, random.shuffle, arr) + arr = np.ones((3, 2)) + assert_raises(np.AxisError, random.shuffle, arr, 2) + def test_permutation(self): random = Generator(MT19937(self.seed)) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] @@ -1004,7 +1017,7 @@ class TestRandomDist: arr_2d = np.atleast_2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]).T actual = random.permutation(arr_2d) assert_array_equal(actual, np.atleast_2d(desired).T) - + bad_x_str = "abcd" assert_raises(np.AxisError, random.permutation, bad_x_str) diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index b70a04347..7f5f08050 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -642,7 +642,7 @@ class TestRandomDist: a = np.array([42, 1, 2]) p = [None, None, None] assert_raises(ValueError, random.choice, a, p=p) - + def test_choice_p_non_contiguous(self): p = 
np.ones(10) / 5 p[1::2] = 3.0 @@ -699,6 +699,10 @@ class TestRandomDist: assert_equal( sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask])) + def test_shuffle_invalid_objects(self): + x = np.array(3) + assert_raises(TypeError, random.shuffle, x) + def test_permutation(self): random.seed(self.seed) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] diff --git a/numpy/typing/tests/data/mypy.ini b/numpy/typing/tests/data/mypy.ini index 35cfbec89..548f76261 100644 --- a/numpy/typing/tests/data/mypy.ini +++ b/numpy/typing/tests/data/mypy.ini @@ -1,5 +1,6 @@ [mypy] plugins = numpy.typing.mypy_plugin +show_absolute_path = True [mypy-numpy] ignore_errors = True diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py index 18520a757..324312a92 100644 --- a/numpy/typing/tests/test_typing.py +++ b/numpy/typing/tests/test_typing.py @@ -25,15 +25,48 @@ REVEAL_DIR = os.path.join(DATA_DIR, "reveal") MYPY_INI = os.path.join(DATA_DIR, "mypy.ini") CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache") +#: A dictionary with file names as keys and lists of the mypy stdout as values. +#: To be populated by `run_mypy`. +OUTPUT_MYPY: Dict[str, List[str]] = {} + + +def _key_func(key: str) -> str: + """Split at the first occurrence of the ``:`` character. + + Windows drive-letters (*e.g.* ``C:``) are ignored herein. + """ + drive, tail = os.path.splitdrive(key) + return os.path.join(drive, tail.split(":", 1)[0]) + @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") -@pytest.fixture(scope="session", autouse=True) -def clear_cache() -> None: - """Clears the mypy cache before running any of the typing tests.""" +@pytest.fixture(scope="module", autouse=True) +def run_mypy() -> None: + """Clear the cache and run mypy before running any of the typing tests. + + The mypy results are cached in `OUTPUT_MYPY` for further use. 
+ + """ if os.path.isdir(CACHE_DIR): shutil.rmtree(CACHE_DIR) + for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR): + # Run mypy + stdout, stderr, _ = api.run([ + "--config-file", + MYPY_INI, + "--cache-dir", + CACHE_DIR, + directory, + ]) + assert not stderr, directory + stdout = stdout.replace('*', '') + + # Parse the output + iterator = itertools.groupby(stdout.split("\n"), key=_key_func) + OUTPUT_MYPY.update((k, list(v)) for k, v in iterator if k) + def get_test_cases(directory): for root, _, files in os.walk(directory): @@ -54,15 +87,9 @@ def get_test_cases(directory): @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.mark.parametrize("path", get_test_cases(PASS_DIR)) def test_success(path): - stdout, stderr, exitcode = api.run([ - "--config-file", - MYPY_INI, - "--cache-dir", - CACHE_DIR, - path, - ]) - assert exitcode == 0, stdout - assert re.match(r"Success: no issues found in \d+ source files?", stdout.strip()) + # Alias `OUTPUT_MYPY` so that it appears in the local namespace + output_mypy = OUTPUT_MYPY + assert path not in output_mypy @pytest.mark.slow @@ -71,29 +98,14 @@ def test_success(path): def test_fail(path): __tracebackhide__ = True - stdout, stderr, exitcode = api.run([ - "--config-file", - MYPY_INI, - "--cache-dir", - CACHE_DIR, - path, - ]) - assert exitcode != 0 - with open(path) as fin: lines = fin.readlines() errors = defaultdict(lambda: "") - error_lines = stdout.rstrip("\n").split("\n") - assert re.match( - r"Found \d+ errors? in \d+ files? 
\(checked \d+ source files?\)", - error_lines[-1].strip(), - ) - for error_line in error_lines[:-1]: - error_line = error_line.strip() - if not error_line: - continue + output_mypy = OUTPUT_MYPY + assert path in output_mypy + for error_line in output_mypy[path]: match = re.match( r"^.+\.py:(?P<lineno>\d+): (error|note): .+$", error_line, @@ -215,23 +227,12 @@ def _parse_reveals(file: IO[str]) -> List[str]: def test_reveal(path): __tracebackhide__ = True - stdout, stderr, exitcode = api.run([ - "--config-file", - MYPY_INI, - "--cache-dir", - CACHE_DIR, - path, - ]) - with open(path) as fin: lines = _parse_reveals(fin) - stdout_list = stdout.replace('*', '').split("\n") - for error_line in stdout_list: - error_line = error_line.strip() - if not error_line: - continue - + output_mypy = OUTPUT_MYPY + assert path in output_mypy + for error_line in output_mypy[path]: match = re.match( r"^.+\.py:(?P<lineno>\d+): note: .+$", error_line, |