diff options
Diffstat (limited to 'numpy/lib/stride_tricks.py')
-rw-r--r-- | numpy/lib/stride_tricks.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 12f8bbf13..b81307a65 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -20,9 +20,11 @@ class DummyArray(object): self.__array_interface__ = interface self.base = base -def as_strided(x, shape=None, strides=None): +def as_strided(x, shape=None, strides=None, subok=False): """ Make an ndarray from the given array with the given shape and strides. """ + # first convert input to array, possibly keeping subclass + x = np.array(x, copy=False, subok=subok) interface = dict(x.__array_interface__) if shape is not None: interface['shape'] = tuple(shape) @@ -32,9 +34,17 @@ def as_strided(x, shape=None, strides=None): # Make sure dtype is correct in case of custom dtype if array.dtype.kind == 'V': array.dtype = x.dtype + if type(x) is not type(array): + # if input was an ndarray subclass and subclasses were OK, + # then view the result as that subclass. + array = array.view(type=type(x)) + # Since we have done something akin to a view from x, we should let + # the subclass finalize (if it has it implemented, i.e., is not None). + if array.__array_finalize__: + array.__array_finalize__(x) return array -def broadcast_arrays(*args): +def broadcast_arrays(*args, **kwargs): """ Broadcast any number of arrays against each other. @@ -43,6 +53,10 @@ def broadcast_arrays(*args): `*args` : array_likes The arrays to broadcast. + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise + the returned arrays will be forced to be a base-class array (default). + Returns ------- broadcasted : list of arrays @@ -73,7 +87,11 @@ def broadcast_arrays(*args): [3, 3, 3]])] """ - args = [np.asarray(_m) for _m in args] + subok = kwargs.pop('subok', False) + if kwargs: + raise TypeError('broadcast_arrays() got an unexpected keyword ' + 'argument {}'.format(kwargs.pop())) + args = [np.array(_m, copy=False, subok=subok) for _m in args] shapes = [x.shape for x in args] if len(set(shapes)) == 1: # Common case where nothing needs to be broadcasted. @@ -118,6 +136,6 @@ def broadcast_arrays(*args): common_shape.append(1) # Construct the new arrays. - broadcasted = [as_strided(x, shape=sh, strides=st) for (x, sh, st) in - zip(args, shapes, strides)] + broadcasted = [as_strided(x, shape=sh, strides=st, subok=subok) + for (x, sh, st) in zip(args, shapes, strides)] return broadcasted |