diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-04-15 21:32:42 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-08-25 13:08:19 -0400 |
commit | 26a02cd702d9ccfc48978dcf81c80225f324bf3b (patch) | |
tree | 27e5a930f4809457c10dbabc100127c73b0aa2f0 /numpy/lib/stride_tricks.py | |
parent | 14e4cc3aa8bc6e01d6494860c7de6bf9ec0ab86b (diff) | |
download | numpy-26a02cd702d9ccfc48978dcf81c80225f324bf3b.tar.gz |
ENH: add subok flag to stride_tricks (and thus broadcast_arrays)
Diffstat (limited to 'numpy/lib/stride_tricks.py')
-rw-r--r-- | numpy/lib/stride_tricks.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 12f8bbf13..c15bc5167 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -20,7 +20,7 @@ class DummyArray(object): self.__array_interface__ = interface self.base = base -def as_strided(x, shape=None, strides=None): +def as_strided(x, shape=None, strides=None, subok=False): """ Make an ndarray from the given array with the given shape and strides. """ interface = dict(x.__array_interface__) @@ -32,9 +32,15 @@ def as_strided(x, shape=None, strides=None): # Make sure dtype is correct in case of custom dtype if array.dtype.kind == 'V': array.dtype = x.dtype + if subok and isinstance(x, np.ndarray) and type(x) is not type(array): + array = array.view(type=type(x)) + # we have done something akin to a view from x, so we should let a + # possible subclass finalize (if it has it implemented) + if callable(getattr(array, '__array_finalize__', None)): + array.__array_finalize__(x) return array -def broadcast_arrays(*args): +def broadcast_arrays(*args, **kwargs): """ Broadcast any number of arrays against each other. @@ -43,6 +49,10 @@ def broadcast_arrays(*args): `*args` : array_likes The arrays to broadcast. + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise + the returned arrays will be forced to be a base-class array (default). + Returns ------- broadcasted : list of arrays @@ -73,7 +83,11 @@ def broadcast_arrays(*args): [3, 3, 3]])] """ - args = [np.asarray(_m) for _m in args] + subok = kwargs.pop('subok', False) + if kwargs: + raise TypeError('broadcast_arrays() got an unexpected keyword ' + 'argument {}'.format(kwargs.pop())) + args = [np.array(_m, copy=False, subok=subok) for _m in args] shapes = [x.shape for x in args] if len(set(shapes)) == 1: # Common case where nothing needs to be broadcasted. @@ -118,6 +132,6 @@ def broadcast_arrays(*args): common_shape.append(1) # Construct the new arrays. - broadcasted = [as_strided(x, shape=sh, strides=st) for (x, sh, st) in - zip(args, shapes, strides)] + broadcasted = [as_strided(x, shape=sh, strides=st, subok=subok) + for (x, sh, st) in zip(args, shapes, strides)] return broadcasted |