summaryrefslogtreecommitdiff
path: root/numpy/lib/stride_tricks.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-04-15 21:32:42 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-08-25 13:08:19 -0400
commit26a02cd702d9ccfc48978dcf81c80225f324bf3b (patch)
tree27e5a930f4809457c10dbabc100127c73b0aa2f0 /numpy/lib/stride_tricks.py
parent14e4cc3aa8bc6e01d6494860c7de6bf9ec0ab86b (diff)
downloadnumpy-26a02cd702d9ccfc48978dcf81c80225f324bf3b.tar.gz
ENH: add subok flag to stride_tricks (and thus broadcast_arrays)
Diffstat (limited to 'numpy/lib/stride_tricks.py')
-rw-r--r--numpy/lib/stride_tricks.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index 12f8bbf13..c15bc5167 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -20,7 +20,7 @@ class DummyArray(object):
self.__array_interface__ = interface
self.base = base
-def as_strided(x, shape=None, strides=None):
+def as_strided(x, shape=None, strides=None, subok=False):
""" Make an ndarray from the given array with the given shape and strides.
"""
interface = dict(x.__array_interface__)
@@ -32,9 +32,15 @@ def as_strided(x, shape=None, strides=None):
# Make sure dtype is correct in case of custom dtype
if array.dtype.kind == 'V':
array.dtype = x.dtype
+ if subok and isinstance(x, np.ndarray) and type(x) is not type(array):
+ array = array.view(type=type(x))
+ # we have done something akin to a view from x, so we should let a
+ # possible subclass finalize (if it has it implemented)
+ if callable(getattr(array, '__array_finalize__', None)):
+ array.__array_finalize__(x)
return array
-def broadcast_arrays(*args):
+def broadcast_arrays(*args, **kwargs):
"""
Broadcast any number of arrays against each other.
@@ -43,6 +49,10 @@ def broadcast_arrays(*args):
`*args` : array_likes
The arrays to broadcast.
+ subok : bool, optional
+ If True, then sub-classes will be passed-through, otherwise
+ the returned arrays will be forced to be a base-class array (default).
+
Returns
-------
broadcasted : list of arrays
@@ -73,7 +83,11 @@ def broadcast_arrays(*args):
[3, 3, 3]])]
"""
- args = [np.asarray(_m) for _m in args]
+ subok = kwargs.pop('subok', False)
+ if kwargs:
+ raise TypeError('broadcast_arrays() got an unexpected keyword '
+ 'argument {}'.format(kwargs.pop()))
+ args = [np.array(_m, copy=False, subok=subok) for _m in args]
shapes = [x.shape for x in args]
if len(set(shapes)) == 1:
# Common case where nothing needs to be broadcasted.
@@ -118,6 +132,6 @@ def broadcast_arrays(*args):
common_shape.append(1)
# Construct the new arrays.
- broadcasted = [as_strided(x, shape=sh, strides=st) for (x, sh, st) in
- zip(args, shapes, strides)]
+ broadcasted = [as_strided(x, shape=sh, strides=st, subok=subok)
+ for (x, sh, st) in zip(args, shapes, strides)]
return broadcasted