summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/_generator.pyx22
-rw-r--r--numpy/random/mtrand.pyx20
-rw-r--r--numpy/random/tests/test_generator_mt19937.py15
-rw-r--r--numpy/random/tests/test_randomstate.py6
-rw-r--r--numpy/typing/tests/data/mypy.ini1
-rw-r--r--numpy/typing/tests/test_typing.py89
6 files changed, 100 insertions, 53 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 7d652ce89..3033a1495 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -2,6 +2,7 @@
#cython: wraparound=False, nonecheck=False, boundscheck=False, cdivision=True, language_level=3
import operator
import warnings
+from collections.abc import MutableSequence
from cpython.pycapsule cimport PyCapsule_IsValid, PyCapsule_GetPointer
from cpython cimport (Py_INCREF, PyFloat_AsDouble)
@@ -4347,14 +4348,14 @@ cdef class Generator:
"""
shuffle(x, axis=0)
- Modify a sequence in-place by shuffling its contents.
+ Modify an array or sequence in-place by shuffling its contents.
The order of sub-arrays is changed but their contents remains the same.
Parameters
----------
- x : array_like
- The array or list to be shuffled.
+ x : ndarray or MutableSequence
+ The array, list or mutable sequence to be shuffled.
axis : int, optional
The axis which `x` is shuffled along. Default is 0.
It is only supported on `ndarray` objects.
@@ -4414,7 +4415,11 @@ cdef class Generator:
with self.lock, nogil:
_shuffle_raw_wrap(&self._bitgen, n, 1, itemsize, stride,
x_ptr, buf_ptr)
- elif isinstance(x, np.ndarray) and x.ndim and x.size:
+ elif isinstance(x, np.ndarray):
+ if x.size == 0:
+ # shuffling is a no-op
+ return
+
x = np.swapaxes(x, 0, axis)
buf = np.empty_like(x[0, ...])
with self.lock:
@@ -4428,6 +4433,15 @@ cdef class Generator:
x[i] = buf
else:
# Untyped path.
+ if not isinstance(x, MutableSequence):
+ # See gh-18206. We may decide to deprecate here in the future.
+ warnings.warn(
+ "`x` isn't a recognized object; `shuffle` is not guaranteed "
+ "to behave correctly. E.g., non-numpy array/tensor objects "
+ "with view semantics may contain duplicates after shuffling.",
+ UserWarning, stacklevel=2
+ )
+
if axis != 0:
raise NotImplementedError("Axis argument is only supported "
"on ndarray objects")
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index d43e7f5aa..814630c03 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -2,6 +2,7 @@
#cython: wraparound=False, nonecheck=False, boundscheck=False, cdivision=True, language_level=3
import operator
import warnings
+from collections.abc import MutableSequence
import numpy as np
@@ -4402,8 +4403,8 @@ cdef class RandomState:
Parameters
----------
- x : array_like
- The array or list to be shuffled.
+ x : ndarray or MutableSequence
+ The array, list or mutable sequence to be shuffled.
Returns
-------
@@ -4456,7 +4457,11 @@ cdef class RandomState:
self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
else:
self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr)
- elif isinstance(x, np.ndarray) and x.ndim and x.size:
+ elif isinstance(x, np.ndarray):
+ if x.size == 0:
+ # shuffling is a no-op
+ return
+
buf = np.empty_like(x[0, ...])
with self.lock:
for i in reversed(range(1, n)):
@@ -4468,6 +4473,15 @@ cdef class RandomState:
x[i] = buf
else:
# Untyped path.
+ if not isinstance(x, MutableSequence):
+ # See gh-18206. We may decide to deprecate here in the future.
+ warnings.warn(
+ "`x` isn't a recognized object; `shuffle` is not guaranteed "
+ "to behave correctly. E.g., non-numpy array/tensor objects "
+ "with view semantics may contain duplicates after shuffling.",
+ UserWarning, stacklevel=2
+ )
+
with self.lock:
for i in reversed(range(1, n)):
j = random_interval(&self._bitgen, i)
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index c4fb5883c..47c81584c 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -960,6 +960,14 @@ class TestRandomDist:
random.shuffle(actual, axis=-1)
assert_array_equal(actual, desired)
+ def test_shuffle_custom_axis_empty(self):
+ random = Generator(MT19937(self.seed))
+ desired = np.array([]).reshape((0, 6))
+ for axis in (0, 1):
+ actual = np.array([]).reshape((0, 6))
+ random.shuffle(actual, axis=axis)
+ assert_array_equal(actual, desired)
+
def test_shuffle_axis_nonsquare(self):
y1 = np.arange(20).reshape(2, 10)
y2 = y1.copy()
@@ -993,6 +1001,11 @@ class TestRandomDist:
arr = [[1, 2, 3], [4, 5, 6]]
assert_raises(NotImplementedError, random.shuffle, arr, 1)
+ arr = np.array(3)
+ assert_raises(TypeError, random.shuffle, arr)
+ arr = np.ones((3, 2))
+ assert_raises(np.AxisError, random.shuffle, arr, 2)
+
def test_permutation(self):
random = Generator(MT19937(self.seed))
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
@@ -1004,7 +1017,7 @@ class TestRandomDist:
arr_2d = np.atleast_2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]).T
actual = random.permutation(arr_2d)
assert_array_equal(actual, np.atleast_2d(desired).T)
-
+
bad_x_str = "abcd"
assert_raises(np.AxisError, random.permutation, bad_x_str)
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index b70a04347..7f5f08050 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -642,7 +642,7 @@ class TestRandomDist:
a = np.array([42, 1, 2])
p = [None, None, None]
assert_raises(ValueError, random.choice, a, p=p)
-
+
def test_choice_p_non_contiguous(self):
p = np.ones(10) / 5
p[1::2] = 3.0
@@ -699,6 +699,10 @@ class TestRandomDist:
assert_equal(
sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask]))
+ def test_shuffle_invalid_objects(self):
+ x = np.array(3)
+ assert_raises(TypeError, random.shuffle, x)
+
def test_permutation(self):
random.seed(self.seed)
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
diff --git a/numpy/typing/tests/data/mypy.ini b/numpy/typing/tests/data/mypy.ini
index 35cfbec89..548f76261 100644
--- a/numpy/typing/tests/data/mypy.ini
+++ b/numpy/typing/tests/data/mypy.ini
@@ -1,5 +1,6 @@
[mypy]
plugins = numpy.typing.mypy_plugin
+show_absolute_path = True
[mypy-numpy]
ignore_errors = True
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index 18520a757..324312a92 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -25,15 +25,48 @@ REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
MYPY_INI = os.path.join(DATA_DIR, "mypy.ini")
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
+#: A dictionary with file names as keys and lists of the mypy stdout as values.
+#: To-be populated by `run_mypy`.
+OUTPUT_MYPY: Dict[str, List[str]] = {}
+
+
+def _key_func(key: str) -> str:
+ """Split at the first occurance of the ``:`` character.
+
+ Windows drive-letters (*e.g.* ``C:``) are ignored herein.
+ """
+ drive, tail = os.path.splitdrive(key)
+ return os.path.join(drive, tail.split(":", 1)[0])
+
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
-@pytest.fixture(scope="session", autouse=True)
-def clear_cache() -> None:
- """Clears the mypy cache before running any of the typing tests."""
+@pytest.fixture(scope="module", autouse=True)
+def run_mypy() -> None:
+ """Clears the cache and run mypy before running any of the typing tests.
+
+ The mypy results are cached in `OUTPUT_MYPY` for further use.
+
+ """
if os.path.isdir(CACHE_DIR):
shutil.rmtree(CACHE_DIR)
+ for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR):
+ # Run mypy
+ stdout, stderr, _ = api.run([
+ "--config-file",
+ MYPY_INI,
+ "--cache-dir",
+ CACHE_DIR,
+ directory,
+ ])
+ assert not stderr, directory
+ stdout = stdout.replace('*', '')
+
+ # Parse the output
+ iterator = itertools.groupby(stdout.split("\n"), key=_key_func)
+ OUTPUT_MYPY.update((k, list(v)) for k, v in iterator if k)
+
def get_test_cases(directory):
for root, _, files in os.walk(directory):
@@ -54,15 +87,9 @@ def get_test_cases(directory):
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
def test_success(path):
- stdout, stderr, exitcode = api.run([
- "--config-file",
- MYPY_INI,
- "--cache-dir",
- CACHE_DIR,
- path,
- ])
- assert exitcode == 0, stdout
- assert re.match(r"Success: no issues found in \d+ source files?", stdout.strip())
+ # Alias `OUTPUT_MYPY` so that it appears in the local namespace
+ output_mypy = OUTPUT_MYPY
+ assert path not in output_mypy
@pytest.mark.slow
@@ -71,29 +98,14 @@ def test_success(path):
def test_fail(path):
__tracebackhide__ = True
- stdout, stderr, exitcode = api.run([
- "--config-file",
- MYPY_INI,
- "--cache-dir",
- CACHE_DIR,
- path,
- ])
- assert exitcode != 0
-
with open(path) as fin:
lines = fin.readlines()
errors = defaultdict(lambda: "")
- error_lines = stdout.rstrip("\n").split("\n")
- assert re.match(
- r"Found \d+ errors? in \d+ files? \(checked \d+ source files?\)",
- error_lines[-1].strip(),
- )
- for error_line in error_lines[:-1]:
- error_line = error_line.strip()
- if not error_line:
- continue
+ output_mypy = OUTPUT_MYPY
+ assert path in output_mypy
+ for error_line in output_mypy[path]:
match = re.match(
r"^.+\.py:(?P<lineno>\d+): (error|note): .+$",
error_line,
@@ -215,23 +227,12 @@ def _parse_reveals(file: IO[str]) -> List[str]:
def test_reveal(path):
__tracebackhide__ = True
- stdout, stderr, exitcode = api.run([
- "--config-file",
- MYPY_INI,
- "--cache-dir",
- CACHE_DIR,
- path,
- ])
-
with open(path) as fin:
lines = _parse_reveals(fin)
- stdout_list = stdout.replace('*', '').split("\n")
- for error_line in stdout_list:
- error_line = error_line.strip()
- if not error_line:
- continue
-
+ output_mypy = OUTPUT_MYPY
+ assert path in output_mypy
+ for error_line in output_mypy[path]:
match = re.match(
r"^.+\.py:(?P<lineno>\d+): note: .+$",
error_line,