diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-05-05 15:28:24 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-05 15:28:24 -0600 |
commit | d761fd6ccbc038970798eb7dfb1a5de825653ea8 (patch) | |
tree | 4aa9c85be386c7a4a061c3cab76e5f044e3ddd58 | |
parent | b7d30978eaa6c36701958afa19c09f2a0b31670d (diff) | |
parent | 36e7513edd1114c3f928be66953d4349273122c0 (diff) | |
download | numpy-d761fd6ccbc038970798eb7dfb1a5de825653ea8.tar.gz |
Merge pull request #8816 from eric-wieser/fix-r_
BUG: np.lib.index_tricks.r_ mutates its own state
-rw-r--r-- | numpy/lib/index_tricks.py | 84 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 31 | ||||
-rw-r--r-- | numpy/ma/extras.py | 57 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 15 |
4 files changed, 94 insertions, 93 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 1fd530f33..dc8eb1c4a 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -10,13 +10,11 @@ from numpy.core.numeric import ( from numpy.core.numerictypes import find_common_type, issubdtype from . import function_base -import numpy.matrixlib as matrix +import numpy.matrixlib as matrixlib from .function_base import diff from numpy.core.multiarray import ravel_multi_index, unravel_index from numpy.lib.stride_tricks import as_strided -makemat = matrix.matrix - __all__ = [ 'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_', @@ -235,48 +233,44 @@ class AxisConcatenator(object): Translates slice objects to concatenation along an axis. For detailed documentation on usage, see `r_`. - """ - - def _retval(self, res): - if self.matrix: - oldndim = res.ndim - res = makemat(res) - if oldndim == 1 and self.col: - res = res.T - self.axis = self._axis - self.matrix = self._matrix - self.col = 0 - return res + # allow ma.mr_ to override this + concatenate = staticmethod(_nx.concatenate) + makemat = staticmethod(matrixlib.matrix) def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1): - self._axis = axis - self._matrix = matrix self.axis = axis self.matrix = matrix - self.col = 0 self.trans1d = trans1d self.ndmin = ndmin def __getitem__(self, key): - trans1d = self.trans1d - ndmin = self.ndmin + # handle matrix builder syntax if isinstance(key, str): frame = sys._getframe().f_back - mymat = matrix.bmat(key, frame.f_globals, frame.f_locals) + mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals) return mymat + if not isinstance(key, tuple): key = (key,) + + # copy attributes, since they can be overriden in the first argument + trans1d = self.trans1d + ndmin = self.ndmin + matrix = self.matrix + axis = self.axis + objs = [] scalars = [] arraytypes = [] scalartypes = [] - for k in range(len(key)): + + for k, item in enumerate(key): scalar = False - if isinstance(key[k], slice): - step = key[k].step - start = key[k].start - stop = key[k].stop + if isinstance(item, slice): + step = item.step + start = item.start + stop = item.stop if start is None: start = 0 if step is None: @@ -290,37 +284,35 @@ class AxisConcatenator(object): newobj = array(newobj, copy=False, ndmin=ndmin) if trans1d != -1: newobj = newobj.swapaxes(-1, trans1d) - elif isinstance(key[k], str): + elif isinstance(item, str): if k != 0: raise ValueError("special directives must be the " "first entry.") - key0 = key[0] - if key0 in 'rc': - self.matrix = True - self.col = (key0 == 'c') + if item in ('r', 'c'): + matrix = True + col = (item == 'c') continue - if ',' in key0: - vec = key0.split(',') + if ',' in item: + vec = item.split(',') try: - self.axis, ndmin = \ - [int(x) for x in vec[:2]] + axis, ndmin = [int(x) for x in vec[:2]] if len(vec) == 3: trans1d = int(vec[2]) continue except: raise ValueError("unknown special directive") try: - self.axis = int(key[k]) + axis = int(item) continue except (ValueError, TypeError): raise ValueError("unknown special directive") - elif type(key[k]) in ScalarType: - newobj = array(key[k], ndmin=ndmin) - scalars.append(k) + elif type(item) in ScalarType: + newobj = array(item, ndmin=ndmin) + scalars.append(len(objs)) scalar = True scalartypes.append(newobj.dtype) else: - newobj = key[k] + newobj = item if ndmin > 1: tempobj = array(newobj, copy=False, subok=True) newobj = array(newobj, copy=False, subok=True, @@ -339,14 +331,20 @@ class AxisConcatenator(object): if not scalar and isinstance(newobj, _nx.ndarray): arraytypes.append(newobj.dtype) - # Esure that scalars won't up-cast unless warranted + # Ensure that scalars won't up-cast unless warranted final_dtype = find_common_type(arraytypes, scalartypes) if final_dtype is not None: for k in scalars: objs[k] = objs[k].astype(final_dtype) - res = _nx.concatenate(tuple(objs), axis=self.axis) - return self._retval(res) + res = self.concatenate(tuple(objs), axis=axis) + + if matrix: + oldndim = res.ndim + res = self.makemat(res) + if oldndim == 1 and col: + res = res.T + return res def __len__(self): return 0 diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index d9fa1f43e..5b791026b 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -174,6 +174,37 @@ class TestConcatenator(TestCase): assert_array_equal(d[:5, :], b) assert_array_equal(d[5:, :], c) + def test_matrix(self): + a = [1, 2] + b = [3, 4] + + ab_r = np.r_['r', a, b] + ab_c = np.r_['c', a, b] + + assert_equal(type(ab_r), np.matrix) + assert_equal(type(ab_c), np.matrix) + + assert_equal(np.array(ab_r), [[1,2,3,4]]) + assert_equal(np.array(ab_c), [[1],[2],[3],[4]]) + + assert_raises(ValueError, lambda: np.r_['rc', a, b]) + + def test_matrix_scalar(self): + r = np.r_['r', [1, 2], 3] + assert_equal(type(r), np.matrix) + assert_equal(np.array(r), [[1,2,3]]) + + def test_matrix_builder(self): + a = np.array([1]) + b = np.array([2]) + c = np.array([3]) + d = np.array([4]) + actual = np.r_['a, b; c, d'] + expected = np.bmat([[a, b], [c, d]]) + + assert_equal(actual, expected) + assert_equal(type(actual), type(expected)) + class TestNdenumerate(TestCase): def test_basic(self): diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index e100e471c..d55e0d1ea 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1461,60 +1461,19 @@ class MAxisConcatenator(AxisConcatenator): mr_class """ + concatenate = staticmethod(concatenate) - def __init__(self, axis=0): - AxisConcatenator.__init__(self, axis, matrix=False) + @staticmethod + def makemat(arr): + return array(arr.data.view(np.matrix), mask=arr.mask) def __getitem__(self, key): + # matrix builder syntax, like 'a, b; c, d' if isinstance(key, str): raise MAError("Unavailable for masked array.") - if not isinstance(key, tuple): - key = (key,) - objs = [] - scalars = [] - final_dtypedescr = None - for k in range(len(key)): - scalar = False - if isinstance(key[k], slice): - step = key[k].step - start = key[k].start - stop = key[k].stop - if start is None: - start = 0 - if step is None: - step = 1 - if isinstance(step, complex): - size = int(abs(step)) - newobj = np.linspace(start, stop, num=size) - else: - newobj = np.arange(start, stop, step) - elif isinstance(key[k], str): - if (key[k] in 'rc'): - self.matrix = True - self.col = (key[k] == 'c') - continue - try: - self.axis = int(key[k]) - continue - except (ValueError, TypeError): - raise ValueError("Unknown special directive") - elif type(key[k]) in np.ScalarType: - newobj = asarray([key[k]]) - scalars.append(k) - scalar = True - else: - newobj = key[k] - objs.append(newobj) - if isinstance(newobj, ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: - for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) - res = concatenate(tuple(objs), axis=self.axis) - return self._retval(res) + + return super(MAxisConcatenator, self).__getitem__(key) + class mr_class(MAxisConcatenator): """ diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index e7ebd8b82..4b7fe07b6 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -14,7 +14,8 @@ import itertools import numpy as np from numpy.testing import ( - TestCase, run_module_suite, assert_warns, suppress_warnings + TestCase, run_module_suite, assert_warns, suppress_warnings, + assert_raises ) from numpy.ma.testutils import ( assert_, assert_array_equal, assert_equal, assert_almost_equal @@ -304,6 +305,18 @@ class TestConcatenator(TestCase): assert_array_equal(d[5:,:], b_2) assert_array_equal(d.mask, np.r_[m_1, m_2]) + def test_matrix_builder(self): + assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4']) + + def test_matrix(self): + actual = mr_['r', 1, 2, 3] + expected = np.ma.array(np.r_['r', 1, 2, 3]) + assert_array_equal(actual, expected) + + # outer type is masked array, inner type is matrix + assert_equal(type(actual), type(expected)) + assert_equal(type(actual.data), type(expected.data)) + class TestNotMasked(TestCase): # Tests notmasked_edges and notmasked_contiguous. |