summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py53
1 files changed, 52 insertions, 1 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index d16dd8a78..ce2c0a593 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -7,7 +7,8 @@ from numpy.core import product, ndarray
__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
'issubdtype', 'deprecate', 'get_numarray_include',
- 'get_include', 'info', 'source', 'who']
+ 'get_include', 'info', 'source', 'who',
+ 'memory_bounds', 'may_share_memory']
def issubclass_(arg1, arg2):
try:
@@ -101,6 +102,56 @@ def deprecate(func, oldname, newname):
get_numpy_include = deprecate(get_include, 'get_numpy_include', 'get_include')
+#--------------------------------------------
+# Determine if two arrays can share memory
+#--------------------------------------------
+
+def memory_bounds(a):
+ """(low, high) are pointers to the end-points of an array
+
+ low is the first byte
+ high is just *past* the last byte
+
+ The array provided must conform to the Python-side of the array interface
+ """
+ ai = a.__array_interface__
+ a_data = ai['data'][0]
+ astrides = ai['strides']
+ ashape = ai['shape']
+ nd_a = len(ashape)
+ bytes_a = int(ai['typestr'][2:])
+
+ # a_low points to first element of array
+ # a_high points to last element of the array
+
+ a_low = a_high = a_data
+ if astrides is None: # contiguous case
+ a_high += product(ashape, dtype=int)*bytes_a
+ else:
+ for shape, stride in zip(ashape, astrides):
+ if stride < 0:
+ a_low += (shape-1)*stride
+ else:
+ a_high += (shape-1)*stride
+ a_high += bytes_a
+ return a_low, a_high
+
+
+def may_share_memory(a, b):
+ """Determine if two arrays can share memory
+
+ The memory-bounds of a and b are computed. If they overlap then
+ this function returns True. Otherwise, it returns False.
+
+ A return of True does not necessarily mean that the two arrays
+ share any element. It just means that they *might*.
+ """
+ a_low, a_high = memory_bounds(a)
+ b_low, b_high = memory_bounds(b)
+ if b_low >= a_high or a_low >= b_high:
+ return False
+ return True
+
#-----------------------------------------------------------------------------
# Function for output and information on the variables used.
#-----------------------------------------------------------------------------