summaryrefslogtreecommitdiff
path: root/numpy/lib/stride_tricks.py
diff options
context:
space:
mode:
authorRobert Kern <robert.kern@gmail.com>2008-07-03 06:42:28 +0000
committerRobert Kern <robert.kern@gmail.com>2008-07-03 06:42:28 +0000
commitf912322ee454fee1d15da01c5fd951ab2fcb5f99 (patch)
tree9ef9a29f9573775cb3785a78b6838d9f4c3cbf29 /numpy/lib/stride_tricks.py
parenta74f0dfbcdfaf0ed5929fed7a27dc8738709828f (diff)
downloadnumpy-f912322ee454fee1d15da01c5fd951ab2fcb5f99.tar.gz
ENH: Add broadcast_arrays() function to expose broadcasting to pure Python functions that cannot be made to be ufuncs.
Diffstat (limited to 'numpy/lib/stride_tricks.py')
-rw-r--r--numpy/lib/stride_tricks.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
new file mode 100644
index 000000000..25987362f
--- /dev/null
+++ b/numpy/lib/stride_tricks.py
@@ -0,0 +1,109 @@
+""" Utilities that manipulate strides to achieve desirable effects.
+"""
+import numpy as np
+
+__all__ = ['broadcast_arrays']
+
+class DummyArray(object):
+ """ Dummy object that just exists to hang __array_interface__ dictionaries
+ and possibly keep alive a reference to a base array.
+ """
+ def __init__(self, interface, base=None):
+ self.__array_interface__ = interface
+ self.base = base
+
+def as_strided(x, shape=None, strides=None):
+ """ Make an ndarray from the given array with the given shape and strides.
+ """
+ interface = dict(x.__array_interface__)
+ if shape is not None:
+ interface['shape'] = tuple(shape)
+ if strides is not None:
+ interface['strides'] = tuple(strides)
+ return np.asarray(DummyArray(interface, base=x))
+
+def broadcast_arrays(*args):
+ """ Broadcast any number of arrays against each other.
+
+ Parameters
+ ----------
+ *args : arrays
+
+ Returns
+ -------
+ broadcasted : list of arrays
+ These arrays are views on the original arrays. They are typically not
+ contiguous. Furthermore, more than one element of a broadcasted array
+ may refer to a single memory location. If you need to write to the
+ arrays, make copies first.
+
+ Examples
+ --------
+ >>> x = np.array([[1,2,3]])
+ >>> y = np.array([[1],[2],[3]])
+ >>> np.broadcast_arrays(x, y)
+ [array([[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]), array([[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]])]
+
+ Here is a useful idiom for getting contiguous copies instead of
+ non-contiguous views.
+
+ >>> map(np.array, np.broadcast_arrays(x, y))
+ [array([[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]), array([[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]])]
+
+ """
+ args = map(np.asarray, args)
+ shapes = [x.shape for x in args]
+ if len(set(shapes)) == 1:
+ # Common case where nothing needs to be broadcasted.
+ return args
+ shapes = [list(s) for s in shapes]
+ strides = [list(x.strides) for x in args]
+ nds = [len(s) for s in shapes]
+ biggest = max(nds)
+ # Go through each array and prepend dimensions of length 1 to each of the
+ # shapes in order to make the number of dimensions equal.
+ for i in range(len(args)):
+ diff = biggest - nds[i]
+ if diff > 0:
+ shapes[i] = [1] * diff + shapes[i]
+ strides[i] = [0] * diff + strides[i]
+ # Chech each dimension for compatibility. A dimension length of 1 is
+ # accepted as compatible with any other length.
+ common_shape = []
+ for axis in range(biggest):
+ lengths = [s[axis] for s in shapes]
+ unique = set(lengths + [1])
+ if len(unique) > 2:
+ # There must be at least two non-1 lengths for this axis.
+ raise ValueError("shape mismatch: two or more arrays have "
+ "incompatible dimensions on axis %r." % (axis,))
+ elif len(unique) == 2:
+ # There is exactly one non-1 length. The common shape will take this
+ # value.
+ unique.remove(1)
+ new_length = unique.pop()
+ common_shape.append(new_length)
+ # For each array, if this axis is being broadcasted from a length of
+ # 1, then set its stride to 0 so that it repeats its data.
+ for i in range(len(args)):
+ if shapes[i][axis] == 1:
+ shapes[i][axis] = new_length
+ strides[i][axis] = 0
+ else:
+ # Every array has a length of 1 on this axis. Strides can be left
+ # alone as nothing is broadcasted.
+ common_shape.append(1)
+
+ # Construct the new arrays.
+ broadcasted = [as_strided(x, shape=sh, strides=st) for (x,sh,st) in
+ zip(args, shapes, strides)]
+ return broadcasted
+