summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2017-05-05 15:28:24 -0600
committerGitHub <noreply@github.com>2017-05-05 15:28:24 -0600
commitd761fd6ccbc038970798eb7dfb1a5de825653ea8 (patch)
tree4aa9c85be386c7a4a061c3cab76e5f044e3ddd58
parentb7d30978eaa6c36701958afa19c09f2a0b31670d (diff)
parent36e7513edd1114c3f928be66953d4349273122c0 (diff)
downloadnumpy-d761fd6ccbc038970798eb7dfb1a5de825653ea8.tar.gz
Merge pull request #8816 from eric-wieser/fix-r_
BUG: np.lib.index_tricks.r_ mutates its own state
-rw-r--r--numpy/lib/index_tricks.py84
-rw-r--r--numpy/lib/tests/test_index_tricks.py31
-rw-r--r--numpy/ma/extras.py57
-rw-r--r--numpy/ma/tests/test_extras.py15
4 files changed, 94 insertions, 93 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 1fd530f33..dc8eb1c4a 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -10,13 +10,11 @@ from numpy.core.numeric import (
from numpy.core.numerictypes import find_common_type, issubdtype
from . import function_base
-import numpy.matrixlib as matrix
+import numpy.matrixlib as matrixlib
from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
from numpy.lib.stride_tricks import as_strided
-makemat = matrix.matrix
-
__all__ = [
'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
@@ -235,48 +233,44 @@ class AxisConcatenator(object):
Translates slice objects to concatenation along an axis.
For detailed documentation on usage, see `r_`.
-
"""
-
- def _retval(self, res):
- if self.matrix:
- oldndim = res.ndim
- res = makemat(res)
- if oldndim == 1 and self.col:
- res = res.T
- self.axis = self._axis
- self.matrix = self._matrix
- self.col = 0
- return res
+ # allow ma.mr_ to override this
+ concatenate = staticmethod(_nx.concatenate)
+ makemat = staticmethod(matrixlib.matrix)
def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
- self._axis = axis
- self._matrix = matrix
self.axis = axis
self.matrix = matrix
- self.col = 0
self.trans1d = trans1d
self.ndmin = ndmin
def __getitem__(self, key):
- trans1d = self.trans1d
- ndmin = self.ndmin
+ # handle matrix builder syntax
if isinstance(key, str):
frame = sys._getframe().f_back
- mymat = matrix.bmat(key, frame.f_globals, frame.f_locals)
+ mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
return mymat
+
if not isinstance(key, tuple):
key = (key,)
+
+ # copy attributes, since they can be overriden in the first argument
+ trans1d = self.trans1d
+ ndmin = self.ndmin
+ matrix = self.matrix
+ axis = self.axis
+
objs = []
scalars = []
arraytypes = []
scalartypes = []
- for k in range(len(key)):
+
+ for k, item in enumerate(key):
scalar = False
- if isinstance(key[k], slice):
- step = key[k].step
- start = key[k].start
- stop = key[k].stop
+ if isinstance(item, slice):
+ step = item.step
+ start = item.start
+ stop = item.stop
if start is None:
start = 0
if step is None:
@@ -290,37 +284,35 @@ class AxisConcatenator(object):
newobj = array(newobj, copy=False, ndmin=ndmin)
if trans1d != -1:
newobj = newobj.swapaxes(-1, trans1d)
- elif isinstance(key[k], str):
+ elif isinstance(item, str):
if k != 0:
raise ValueError("special directives must be the "
"first entry.")
- key0 = key[0]
- if key0 in 'rc':
- self.matrix = True
- self.col = (key0 == 'c')
+ if item in ('r', 'c'):
+ matrix = True
+ col = (item == 'c')
continue
- if ',' in key0:
- vec = key0.split(',')
+ if ',' in item:
+ vec = item.split(',')
try:
- self.axis, ndmin = \
- [int(x) for x in vec[:2]]
+ axis, ndmin = [int(x) for x in vec[:2]]
if len(vec) == 3:
trans1d = int(vec[2])
continue
except:
raise ValueError("unknown special directive")
try:
- self.axis = int(key[k])
+ axis = int(item)
continue
except (ValueError, TypeError):
raise ValueError("unknown special directive")
- elif type(key[k]) in ScalarType:
- newobj = array(key[k], ndmin=ndmin)
- scalars.append(k)
+ elif type(item) in ScalarType:
+ newobj = array(item, ndmin=ndmin)
+ scalars.append(len(objs))
scalar = True
scalartypes.append(newobj.dtype)
else:
- newobj = key[k]
+ newobj = item
if ndmin > 1:
tempobj = array(newobj, copy=False, subok=True)
newobj = array(newobj, copy=False, subok=True,
@@ -339,14 +331,20 @@ class AxisConcatenator(object):
if not scalar and isinstance(newobj, _nx.ndarray):
arraytypes.append(newobj.dtype)
- # Esure that scalars won't up-cast unless warranted
+ # Ensure that scalars won't up-cast unless warranted
final_dtype = find_common_type(arraytypes, scalartypes)
if final_dtype is not None:
for k in scalars:
objs[k] = objs[k].astype(final_dtype)
- res = _nx.concatenate(tuple(objs), axis=self.axis)
- return self._retval(res)
+ res = self.concatenate(tuple(objs), axis=axis)
+
+ if matrix:
+ oldndim = res.ndim
+ res = self.makemat(res)
+ if oldndim == 1 and col:
+ res = res.T
+ return res
def __len__(self):
return 0
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index d9fa1f43e..5b791026b 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -174,6 +174,37 @@ class TestConcatenator(TestCase):
assert_array_equal(d[:5, :], b)
assert_array_equal(d[5:, :], c)
+ def test_matrix(self):
+ a = [1, 2]
+ b = [3, 4]
+
+ ab_r = np.r_['r', a, b]
+ ab_c = np.r_['c', a, b]
+
+ assert_equal(type(ab_r), np.matrix)
+ assert_equal(type(ab_c), np.matrix)
+
+ assert_equal(np.array(ab_r), [[1,2,3,4]])
+ assert_equal(np.array(ab_c), [[1],[2],[3],[4]])
+
+ assert_raises(ValueError, lambda: np.r_['rc', a, b])
+
+ def test_matrix_scalar(self):
+ r = np.r_['r', [1, 2], 3]
+ assert_equal(type(r), np.matrix)
+ assert_equal(np.array(r), [[1,2,3]])
+
+ def test_matrix_builder(self):
+ a = np.array([1])
+ b = np.array([2])
+ c = np.array([3])
+ d = np.array([4])
+ actual = np.r_['a, b; c, d']
+ expected = np.bmat([[a, b], [c, d]])
+
+ assert_equal(actual, expected)
+ assert_equal(type(actual), type(expected))
+
class TestNdenumerate(TestCase):
def test_basic(self):
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index e100e471c..d55e0d1ea 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1461,60 +1461,19 @@ class MAxisConcatenator(AxisConcatenator):
mr_class
"""
+ concatenate = staticmethod(concatenate)
- def __init__(self, axis=0):
- AxisConcatenator.__init__(self, axis, matrix=False)
+ @staticmethod
+ def makemat(arr):
+ return array(arr.data.view(np.matrix), mask=arr.mask)
def __getitem__(self, key):
+ # matrix builder syntax, like 'a, b; c, d'
if isinstance(key, str):
raise MAError("Unavailable for masked array.")
- if not isinstance(key, tuple):
- key = (key,)
- objs = []
- scalars = []
- final_dtypedescr = None
- for k in range(len(key)):
- scalar = False
- if isinstance(key[k], slice):
- step = key[k].step
- start = key[k].start
- stop = key[k].stop
- if start is None:
- start = 0
- if step is None:
- step = 1
- if isinstance(step, complex):
- size = int(abs(step))
- newobj = np.linspace(start, stop, num=size)
- else:
- newobj = np.arange(start, stop, step)
- elif isinstance(key[k], str):
- if (key[k] in 'rc'):
- self.matrix = True
- self.col = (key[k] == 'c')
- continue
- try:
- self.axis = int(key[k])
- continue
- except (ValueError, TypeError):
- raise ValueError("Unknown special directive")
- elif type(key[k]) in np.ScalarType:
- newobj = asarray([key[k]])
- scalars.append(k)
- scalar = True
- else:
- newobj = key[k]
- objs.append(newobj)
- if isinstance(newobj, ndarray) and not scalar:
- if final_dtypedescr is None:
- final_dtypedescr = newobj.dtype
- elif newobj.dtype > final_dtypedescr:
- final_dtypedescr = newobj.dtype
- if final_dtypedescr is not None:
- for k in scalars:
- objs[k] = objs[k].astype(final_dtypedescr)
- res = concatenate(tuple(objs), axis=self.axis)
- return self._retval(res)
+
+ return super(MAxisConcatenator, self).__getitem__(key)
+
class mr_class(MAxisConcatenator):
"""
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index e7ebd8b82..4b7fe07b6 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -14,7 +14,8 @@ import itertools
import numpy as np
from numpy.testing import (
- TestCase, run_module_suite, assert_warns, suppress_warnings
+ TestCase, run_module_suite, assert_warns, suppress_warnings,
+ assert_raises
)
from numpy.ma.testutils import (
assert_, assert_array_equal, assert_equal, assert_almost_equal
@@ -304,6 +305,18 @@ class TestConcatenator(TestCase):
assert_array_equal(d[5:,:], b_2)
assert_array_equal(d.mask, np.r_[m_1, m_2])
+ def test_matrix_builder(self):
+ assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4'])
+
+ def test_matrix(self):
+ actual = mr_['r', 1, 2, 3]
+ expected = np.ma.array(np.r_['r', 1, 2, 3])
+ assert_array_equal(actual, expected)
+
+ # outer type is masked array, inner type is matrix
+ assert_equal(type(actual), type(expected))
+ assert_equal(type(actual.data), type(expected.data))
+
class TestNotMasked(TestCase):
# Tests notmasked_edges and notmasked_contiguous.