diff options
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 53 |
1 files changed, 52 insertions, 1 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index d16dd8a78..ce2c0a593 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -7,7 +7,8 @@ from numpy.core import product, ndarray __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype', 'issubdtype', 'deprecate', 'get_numarray_include', - 'get_include', 'info', 'source', 'who'] + 'get_include', 'info', 'source', 'who', + 'memory_bounds', 'may_share_memory'] def issubclass_(arg1, arg2): try: @@ -101,6 +102,56 @@ def deprecate(func, oldname, newname): get_numpy_include = deprecate(get_include, 'get_numpy_include', 'get_include') +#-------------------------------------------- +# Determine if two arrays can share memory +#-------------------------------------------- + +def memory_bounds(a): + """(low, high) are pointers to the end-points of an array + + low is the first byte + high is just *past* the last byte + + The array provided must conform to the Python-side of the array interface + """ + ai = a.__array_interface__ + a_data = ai['data'][0] + astrides = ai['strides'] + ashape = ai['shape'] + nd_a = len(ashape) + bytes_a = int(ai['typestr'][2:]) + + # a_low points to first element of array + # a_high points to last element of the array + + a_low = a_high = a_data + if astrides is None: # contiguous case + a_high += product(ashape, dtype=int)*bytes_a + else: + for shape, stride in zip(ashape, astrides): + if stride < 0: + a_low += (shape-1)*stride + else: + a_high += (shape-1)*stride + a_high += bytes_a + return a_low, a_high + + +def may_share_memory(a, b): + """Determine if two arrays can share memory + + The memory-bounds of a and b are computed. If they overlap then + this function returns True. Otherwise, it returns False. + + A return of True does not necessarily mean that the two arrays + share any element. It just means that they *might*. + """ + a_low, a_high = memory_bounds(a) + b_low, b_high = memory_bounds(b) + if b_low >= a_high or a_low >= b_high: + return False + return True + #----------------------------------------------------------------------------- # Function for output and information on the variables used. #----------------------------------------------------------------------------- |