diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-05-05 15:28:24 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-05 15:28:24 -0600 |
commit | d761fd6ccbc038970798eb7dfb1a5de825653ea8 (patch) | |
tree | 4aa9c85be386c7a4a061c3cab76e5f044e3ddd58 /numpy/lib/index_tricks.py | |
parent | b7d30978eaa6c36701958afa19c09f2a0b31670d (diff) | |
parent | 36e7513edd1114c3f928be66953d4349273122c0 (diff) | |
download | numpy-d761fd6ccbc038970798eb7dfb1a5de825653ea8.tar.gz |
Merge pull request #8816 from eric-wieser/fix-r_
BUG: np.lib.index_tricks.r_ mutates its own state
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 84 |
1 files changed, 41 insertions, 43 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 1fd530f33..dc8eb1c4a 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -10,13 +10,11 @@ from numpy.core.numeric import ( from numpy.core.numerictypes import find_common_type, issubdtype from . import function_base -import numpy.matrixlib as matrix +import numpy.matrixlib as matrixlib from .function_base import diff from numpy.core.multiarray import ravel_multi_index, unravel_index from numpy.lib.stride_tricks import as_strided -makemat = matrix.matrix - __all__ = [ 'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_', @@ -235,48 +233,44 @@ class AxisConcatenator(object): Translates slice objects to concatenation along an axis. For detailed documentation on usage, see `r_`. - """ - - def _retval(self, res): - if self.matrix: - oldndim = res.ndim - res = makemat(res) - if oldndim == 1 and self.col: - res = res.T - self.axis = self._axis - self.matrix = self._matrix - self.col = 0 - return res + # allow ma.mr_ to override this + concatenate = staticmethod(_nx.concatenate) + makemat = staticmethod(matrixlib.matrix) def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1): - self._axis = axis - self._matrix = matrix self.axis = axis self.matrix = matrix - self.col = 0 self.trans1d = trans1d self.ndmin = ndmin def __getitem__(self, key): - trans1d = self.trans1d - ndmin = self.ndmin + # handle matrix builder syntax if isinstance(key, str): frame = sys._getframe().f_back - mymat = matrix.bmat(key, frame.f_globals, frame.f_locals) + mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals) return mymat + if not isinstance(key, tuple): key = (key,) + + # copy attributes, since they can be overriden in the first argument + trans1d = self.trans1d + ndmin = self.ndmin + matrix = self.matrix + axis = self.axis + objs = [] scalars = [] arraytypes = [] scalartypes = [] - for k in range(len(key)): + + for k, item in enumerate(key): scalar = False - if isinstance(key[k], slice): - step = key[k].step - start = key[k].start - stop = key[k].stop + if isinstance(item, slice): + step = item.step + start = item.start + stop = item.stop if start is None: start = 0 if step is None: @@ -290,37 +284,35 @@ class AxisConcatenator(object): newobj = array(newobj, copy=False, ndmin=ndmin) if trans1d != -1: newobj = newobj.swapaxes(-1, trans1d) - elif isinstance(key[k], str): + elif isinstance(item, str): if k != 0: raise ValueError("special directives must be the " "first entry.") - key0 = key[0] - if key0 in 'rc': - self.matrix = True - self.col = (key0 == 'c') + if item in ('r', 'c'): + matrix = True + col = (item == 'c') continue - if ',' in key0: - vec = key0.split(',') + if ',' in item: + vec = item.split(',') try: - self.axis, ndmin = \ - [int(x) for x in vec[:2]] + axis, ndmin = [int(x) for x in vec[:2]] if len(vec) == 3: trans1d = int(vec[2]) continue except: raise ValueError("unknown special directive") try: - self.axis = int(key[k]) + axis = int(item) continue except (ValueError, TypeError): raise ValueError("unknown special directive") - elif type(key[k]) in ScalarType: - newobj = array(key[k], ndmin=ndmin) - scalars.append(k) + elif type(item) in ScalarType: + newobj = array(item, ndmin=ndmin) + scalars.append(len(objs)) scalar = True scalartypes.append(newobj.dtype) else: - newobj = key[k] + newobj = item if ndmin > 1: tempobj = array(newobj, copy=False, subok=True) newobj = array(newobj, copy=False, subok=True, @@ -339,14 +331,20 @@ class AxisConcatenator(object): if not scalar and isinstance(newobj, _nx.ndarray): arraytypes.append(newobj.dtype) - # Esure that scalars won't up-cast unless warranted + # Ensure that scalars won't up-cast unless warranted final_dtype = find_common_type(arraytypes, scalartypes) if final_dtype is not None: for k in scalars: objs[k] = objs[k].astype(final_dtype) - res = _nx.concatenate(tuple(objs), axis=self.axis) - return self._retval(res) + res = self.concatenate(tuple(objs), axis=axis) + + if matrix: + oldndim = res.ndim + res = self.makemat(res) + if oldndim == 1 and col: + res = res.T + return res def __len__(self): return 0 |