summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py84
1 files changed, 41 insertions, 43 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 1fd530f33..dc8eb1c4a 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -10,13 +10,11 @@ from numpy.core.numeric import (
from numpy.core.numerictypes import find_common_type, issubdtype
from . import function_base
-import numpy.matrixlib as matrix
+import numpy.matrixlib as matrixlib
from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
from numpy.lib.stride_tricks import as_strided
-makemat = matrix.matrix
-
__all__ = [
'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
@@ -235,48 +233,44 @@ class AxisConcatenator(object):
Translates slice objects to concatenation along an axis.
For detailed documentation on usage, see `r_`.
-
"""
-
- def _retval(self, res):
- if self.matrix:
- oldndim = res.ndim
- res = makemat(res)
- if oldndim == 1 and self.col:
- res = res.T
- self.axis = self._axis
- self.matrix = self._matrix
- self.col = 0
- return res
+ # allow ma.mr_ to override this
+ concatenate = staticmethod(_nx.concatenate)
+ makemat = staticmethod(matrixlib.matrix)
def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
- self._axis = axis
- self._matrix = matrix
self.axis = axis
self.matrix = matrix
- self.col = 0
self.trans1d = trans1d
self.ndmin = ndmin
def __getitem__(self, key):
- trans1d = self.trans1d
- ndmin = self.ndmin
+ # handle matrix builder syntax
if isinstance(key, str):
frame = sys._getframe().f_back
- mymat = matrix.bmat(key, frame.f_globals, frame.f_locals)
+ mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
return mymat
+
if not isinstance(key, tuple):
key = (key,)
+
+ # copy attributes, since they can be overriden in the first argument
+ trans1d = self.trans1d
+ ndmin = self.ndmin
+ matrix = self.matrix
+ axis = self.axis
+
objs = []
scalars = []
arraytypes = []
scalartypes = []
- for k in range(len(key)):
+
+ for k, item in enumerate(key):
scalar = False
- if isinstance(key[k], slice):
- step = key[k].step
- start = key[k].start
- stop = key[k].stop
+ if isinstance(item, slice):
+ step = item.step
+ start = item.start
+ stop = item.stop
if start is None:
start = 0
if step is None:
@@ -290,37 +284,35 @@ class AxisConcatenator(object):
newobj = array(newobj, copy=False, ndmin=ndmin)
if trans1d != -1:
newobj = newobj.swapaxes(-1, trans1d)
- elif isinstance(key[k], str):
+ elif isinstance(item, str):
if k != 0:
raise ValueError("special directives must be the "
"first entry.")
- key0 = key[0]
- if key0 in 'rc':
- self.matrix = True
- self.col = (key0 == 'c')
+ if item in ('r', 'c'):
+ matrix = True
+ col = (item == 'c')
continue
- if ',' in key0:
- vec = key0.split(',')
+ if ',' in item:
+ vec = item.split(',')
try:
- self.axis, ndmin = \
- [int(x) for x in vec[:2]]
+ axis, ndmin = [int(x) for x in vec[:2]]
if len(vec) == 3:
trans1d = int(vec[2])
continue
except:
raise ValueError("unknown special directive")
try:
- self.axis = int(key[k])
+ axis = int(item)
continue
except (ValueError, TypeError):
raise ValueError("unknown special directive")
- elif type(key[k]) in ScalarType:
- newobj = array(key[k], ndmin=ndmin)
- scalars.append(k)
+ elif type(item) in ScalarType:
+ newobj = array(item, ndmin=ndmin)
+ scalars.append(len(objs))
scalar = True
scalartypes.append(newobj.dtype)
else:
- newobj = key[k]
+ newobj = item
if ndmin > 1:
tempobj = array(newobj, copy=False, subok=True)
newobj = array(newobj, copy=False, subok=True,
@@ -339,14 +331,20 @@ class AxisConcatenator(object):
if not scalar and isinstance(newobj, _nx.ndarray):
arraytypes.append(newobj.dtype)
- # Esure that scalars won't up-cast unless warranted
+ # Ensure that scalars won't up-cast unless warranted
final_dtype = find_common_type(arraytypes, scalartypes)
if final_dtype is not None:
for k in scalars:
objs[k] = objs[k].astype(final_dtype)
- res = _nx.concatenate(tuple(objs), axis=self.axis)
- return self._retval(res)
+ res = self.concatenate(tuple(objs), axis=axis)
+
+ if matrix:
+ oldndim = res.ndim
+ res = self.makemat(res)
+ if oldndim == 1 and col:
+ res = res.T
+ return res
def __len__(self):
return 0