diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-12-01 16:11:44 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-12-01 16:11:44 -0800 |
commit | dc608ba2a85a7b89663bfd4d0ad977f40d0f813a (patch) | |
tree | 6c2ad8642036a19f38a31a465ef51b8e47e9a3fc /numpy/core/multiarray.py | |
parent | 12922e2d779aa2aa3b04b334743b5c7d97c12357 (diff) | |
parent | 8b70a86ba2d3ef33ab27c667e65e1ed94f110704 (diff) | |
download | numpy-dc608ba2a85a7b89663bfd4d0ad977f40d0f813a.tar.gz |
Merge pull request #12470 from shoyer/less-multiarray-indirection
MAINT: remove wrapper functions from numpy.core.multiarray
Diffstat (limited to 'numpy/core/multiarray.py')
-rw-r--r-- | numpy/core/multiarray.py | 306 |
1 files changed, 110 insertions, 196 deletions
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py index 78963b0aa..083cca9b1 100644 --- a/numpy/core/multiarray.py +++ b/numpy/core/multiarray.py @@ -60,17 +60,19 @@ seterrobj.__module__ = 'numpy' zeros.__module__ = 'numpy' -array_function_dispatch = functools.partial( - overrides.array_function_dispatch, module='numpy') +# We can't verify dispatcher signatures because NumPy's C functions don't +# support introspection. +array_function_from_c_func_and_dispatcher = functools.partial( + overrides.array_function_from_dispatcher, + module='numpy', docs_from_dispatcher=True, verify=False) -def _empty_like_dispatcher(prototype, dtype=None, order=None, subok=None): - return (prototype,) - +@array_function_from_c_func_and_dispatcher(_multiarray_umath.empty_like) +def empty_like(prototype, dtype=None, order=None, subok=None): + """ + empty_like(prototype, dtype=None, order='K', subok=True) -@array_function_dispatch(_empty_like_dispatcher) -def empty_like(prototype, dtype=None, order='K', subok=True): - """Return a new array with the same shape and type as a given array. + Return a new array with the same shape and type as a given array. Parameters ---------- @@ -124,17 +126,11 @@ def empty_like(prototype, dtype=None, order='K', subok=True): [ 4.38791518e-305, -2.00000715e+000, 4.17269252e-309]]) """ - return _multiarray_umath.empty_like(prototype, dtype, order, subok) - - -def _concatenate_dispatcher(arrays, axis=None, out=None): - for array in arrays: - yield array - yield out + return (prototype,) -@array_function_dispatch(_concatenate_dispatcher) -def concatenate(arrays, axis=0, out=None): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.concatenate) +def concatenate(arrays, axis=None, out=None): """ concatenate((a1, a2, ...), axis=0, out=None) @@ -216,16 +212,16 @@ def concatenate(arrays, axis=0, out=None): fill_value=999999) """ - return _multiarray_umath.concatenate(arrays, axis, out) - - -def _inner_dispatcher(a, b): - return (a, b) + for array in arrays: + yield array + yield out -@array_function_dispatch(_inner_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.inner) def inner(a, b): """ + inner(a, b) + Inner product of two arrays. Ordinary inner product of vectors for 1-D arrays (without complex @@ -295,15 +291,11 @@ def inner(a, b): [ 0., 7.]]) """ - return _multiarray_umath.inner(a, b) - - -def _where_dispatcher(condition, x=None, y=None): - return (condition, x, y) + return (a, b) -@array_function_dispatch(_where_dispatcher) -def where(condition, x=np._NoValue, y=np._NoValue): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.where) +def where(condition, x=None, y=None): """ where(condition, [x, y]) @@ -374,21 +366,14 @@ def where(condition, x=np._NoValue, y=np._NoValue): [ 0, 2, -1], [ 0, 3, -1]]) """ - # _multiarray_umath.where only accepts positional arguments - args = tuple(a for a in (x, y) if a is not np._NoValue) - return _multiarray_umath.where(condition, *args) - - -def _lexsort_dispatcher(keys, axis=None): - if isinstance(keys, tuple): - return keys - else: - return (keys,) + return (condition, x, y) -@array_function_dispatch(_lexsort_dispatcher) -def lexsort(keys, axis=-1): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.lexsort) +def lexsort(keys, axis=None): """ + lexsort(keys, axis=-1) + Perform an indirect stable sort using a sequence of keys. Given multiple sorting keys, which can be interpreted as columns in a @@ -460,16 +445,17 @@ def lexsort(keys, axis=-1): array([2, 0, 4, 6, 5, 3, 1]) """ - return _multiarray_umath.lexsort(keys, axis) - - -def _can_cast_dispatcher(from_, to, casting=None): - return (from_,) + if isinstance(keys, tuple): + return keys + else: + return (keys,) -@array_function_dispatch(_can_cast_dispatcher) -def can_cast(from_, to, casting='safe'): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.can_cast) +def can_cast(from_, to, casting=None): """ + can_cast(from_, to, casting='safe') + Returns True if cast between data types can occur according to the casting rule. If from is a scalar or array scalar, also returns True if the scalar value can be cast without overflow or truncation @@ -573,16 +559,14 @@ def can_cast(from_, to, casting='safe'): True """ - return _multiarray_umath.can_cast(from_, to, casting) - - -def _min_scalar_type_dispatcher(a): - return (a,) + return (from_,) -@array_function_dispatch(_min_scalar_type_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.min_scalar_type) def min_scalar_type(a): """ + min_scalar_type(a) + For scalar ``a``, returns the data type with the smallest size and smallest scalar kind which can hold its value. For non-scalar array ``a``, returns the vector's dtype unmodified. @@ -626,16 +610,14 @@ def min_scalar_type(a): dtype('float64') """ - return _multiarray_umath.min_scalar_type(a) - - -def _result_type_dispatcher(*arrays_and_dtypes): - return arrays_and_dtypes + return (a,) -@array_function_dispatch(_result_type_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.result_type) def result_type(*arrays_and_dtypes): """ + result_type(*arrays_and_dtypes) + Returns the type that results from applying the NumPy type promotion rules to the arguments. @@ -700,16 +682,14 @@ def result_type(*arrays_and_dtypes): dtype('float64') """ - return _multiarray_umath.result_type(*arrays_and_dtypes) - - -def _dot_dispatcher(a, b, out=None): - return (a, b, out) + return arrays_and_dtypes -@array_function_dispatch(_dot_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.dot) def dot(a, b, out=None): """ + dot(a, b, out=None) + Dot product of two arrays. Specifically, - If both `a` and `b` are 1-D arrays, it is inner product of vectors @@ -790,16 +770,14 @@ def dot(a, b, out=None): 499128 """ - return _multiarray_umath.dot(a, b, out) - - -def _vdot_dispatcher(a, b): - return (a, b) + return (a, b, out) -@array_function_dispatch(_vdot_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.vdot) def vdot(a, b): """ + vdot(a, b) + Return the dot product of two vectors. The vdot(`a`, `b`) function handles complex numbers differently than @@ -850,16 +828,14 @@ def vdot(a, b): 30 """ - return _multiarray_umath.vdot(a, b) - - -def _bincount_dispatcher(x, weights=None, minlength=None): - return (x, weights) + return (a, b) -@array_function_dispatch(_bincount_dispatcher) -def bincount(x, weights=None, minlength=0): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.bincount) +def bincount(x, weights=None, minlength=None): """ + bincount(x, weights=None, minlength=0) + Count number of occurrences of each value in array of non-negative ints. The number of bins (of size 1) is one larger than the largest value in @@ -928,16 +904,14 @@ def bincount(x, weights=None, minlength=0): array([ 0.3, 0.7, 1.1]) """ - return _multiarray_umath.bincount(x, weights=weights, minlength=minlength) - - -def _ravel_multi_index_dispatcher(multi_index, dims, mode=None, order=None): - return multi_index + return (x, weights) -@array_function_dispatch(_ravel_multi_index_dispatcher) -def ravel_multi_index(multi_index, dims, mode='raise', order='C'): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.ravel_multi_index) +def ravel_multi_index(multi_index, dims, mode=None, order=None): """ + ravel_multi_index(multi_index, dims, mode='raise', order='C') + Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index. @@ -991,26 +965,14 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): >>> np.ravel_multi_index((3,1,4,1), (6,7,8,9)) 1621 """ - return _multiarray_umath.ravel_multi_index( - multi_index, dims, mode=mode, order=order) - - -def _deprecate_dims(shape, dims): - if dims is not None: - warnings.warn("'shape' argument should be used instead of 'dims'", - DeprecationWarning, stacklevel=3) - shape = dims - return shape - - -def _unravel_index_dispatcher(indices, shape=None, order=None, dims=None): - shape = _deprecate_dims(shape, dims) - return (indices,) + return multi_index -@array_function_dispatch(_unravel_index_dispatcher) -def unravel_index(indices, shape=None, order='C', dims=None): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.unravel_index) +def unravel_index(indices, shape=None, order=None, dims=None): """ + unravel_index(indices, shape, order='C') + Converts a flat index or array of flat indices into a tuple of coordinate arrays. @@ -1053,17 +1015,17 @@ def unravel_index(indices, shape=None, order='C', dims=None): (3, 1, 4, 1) """ - shape = _deprecate_dims(shape, dims) - return _multiarray_umath.unravel_index(indices, shape, order=order) - - -def _copyto_dispatcher(dst, src, casting=None, where=None): - return (dst, src, where) + if dims is not None: + warnings.warn("'shape' argument should be used instead of 'dims'", + DeprecationWarning, stacklevel=3) + return (indices,) -@array_function_dispatch(_copyto_dispatcher) -def copyto(dst, src, casting='same_kind', where=True): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.copyto) +def copyto(dst, src, casting=None, where=None): """ + copyto(dst, src, casting='same_kind', where=True) + Copies values from one array to another, broadcasting as necessary. Raises a TypeError if the `casting` rule is violated, and if @@ -1091,16 +1053,14 @@ def copyto(dst, src, casting='same_kind', where=True): of `dst`, and selects elements to copy from `src` to `dst` wherever it contains the value True. """ - return _multiarray_umath.copyto(dst, src, casting=casting, where=where) - - -def _putmask_dispatcher(a, mask, values): - return (a, mask, values) + return (dst, src, where) -@array_function_dispatch(_putmask_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.putmask) def putmask(a, mask, values): """ + putmask(a, mask, values) + Changes elements of an array based on conditional and input values. Sets ``a.flat[n] = values[n]`` for each n where ``mask.flat[n]==True``. @@ -1138,16 +1098,14 @@ def putmask(a, mask, values): array([ 0, 1, -33, -44, -33]) """ - return _multiarray_umath.putmask(a, mask, values) - - -def _packbits_and_unpackbits_dispatcher(myarray, axis=None): - return (myarray,) + return (a, mask, values) -@array_function_dispatch(_packbits_and_unpackbits_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.packbits) def packbits(myarray, axis=None): """ + packbits(myarray, axis=None) + Packs the elements of a binary-valued array into bits in a uint8 array. The result is padded to full bytes by inserting zero bits at the end. @@ -1188,12 +1146,14 @@ def packbits(myarray, axis=None): and 32 = 0010 0000. """ - return _multiarray_umath.packbits(myarray, axis) + return (myarray,) -@array_function_dispatch(_packbits_and_unpackbits_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.unpackbits) def unpackbits(myarray, axis=None): """ + unpackbits(myarray, axis=None) + Unpacks elements of a uint8 array into a binary-valued output array. Each element of `myarray` represents a bit-field that should be unpacked @@ -1233,16 +1193,14 @@ def unpackbits(myarray, axis=None): [0, 0, 0, 1, 0, 1, 1, 1]], dtype=uint8) """ - return _multiarray_umath.unpackbits(myarray, axis) - - -def _shares_memory_dispatcher(a, b, max_work=None): - return (a, b) + return (myarray,) -@array_function_dispatch(_shares_memory_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.shares_memory) def shares_memory(a, b, max_work=None): """ + shares_memory(a, b, max_work=None) + Determine if two arrays share memory Parameters @@ -1279,12 +1237,14 @@ def shares_memory(a, b, max_work=None): False """ - return _multiarray_umath.shares_memory(a, b, max_work=max_work) + return (a, b) -@array_function_dispatch(_shares_memory_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.may_share_memory) def may_share_memory(a, b, max_work=None): """ + may_share_memory(a, b, max_work=None) + Determine if two arrays might share memory A return of True does not necessarily mean that the two arrays @@ -1318,17 +1278,11 @@ def may_share_memory(a, b, max_work=None): True """ - return _multiarray_umath.may_share_memory(a, b, max_work=max_work) - - -def _is_busday_dispatcher( - dates, weekmask=None, holidays=None, busdaycal=None, out=None): - return (dates, weekmask, holidays, out) + return (a, b) -@array_function_dispatch(_is_busday_dispatcher) -def is_busday(dates, weekmask=None, holidays=None, busdaycal=None, - out=None): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.is_busday) +def is_busday(dates, weekmask=None, holidays=None, busdaycal=None, out=None): """ is_busday(dates, weekmask='1111100', holidays=None, busdaycal=None, out=None) @@ -1378,26 +1332,12 @@ def is_busday(dates, weekmask=None, holidays=None, busdaycal=None, ... holidays=['2011-07-01', '2011-07-04', '2011-07-17']) array([False, False, True], dtype='bool') """ - kwargs = {} - if weekmask is not None: - kwargs['weekmask'] = weekmask - if holidays is not None: - kwargs['holidays'] = holidays - if busdaycal is not None: - kwargs['busdaycal'] = busdaycal - if out is not None: - kwargs['out'] = out - return _multiarray_umath.is_busday(dates, **kwargs) - - -def _busday_offset_dispatcher(dates, offsets, roll=None, weekmask=None, - holidays=None, busdaycal=None, out=None): - return (dates, offsets, weekmask, holidays, out) + return (dates, weekmask, holidays, out) -@array_function_dispatch(_busday_offset_dispatcher) -def busday_offset(dates, offsets, roll='raise', weekmask=None, - holidays=None, busdaycal=None, out=None): +@array_function_from_c_func_and_dispatcher(_multiarray_umath.busday_offset) +def busday_offset(dates, offsets, roll=None, weekmask=None, holidays=None, + busdaycal=None, out=None): """ busday_offset(dates, offsets, roll='raise', weekmask='1111100', holidays=None, busdaycal=None, out=None) @@ -1486,24 +1426,10 @@ def busday_offset(dates, offsets, roll='raise', weekmask=None, >>> np.busday_offset('2011-03-22', 1, roll='backward') numpy.datetime64('2011-03-23','D') """ - kwargs = {} - if weekmask is not None: - kwargs['weekmask'] = weekmask - if holidays is not None: - kwargs['holidays'] = holidays - if busdaycal is not None: - kwargs['busdaycal'] = busdaycal - if out is not None: - kwargs['out'] = out - return _multiarray_umath.busday_offset(dates, offsets, roll, **kwargs) - - -def _busday_count_dispatcher(begindates, enddates, weekmask=None, - holidays=None, busdaycal=None, out=None): - return (begindates, enddates, weekmask, holidays, out) + return (dates, offsets, weekmask, holidays, out) -@array_function_dispatch(_busday_count_dispatcher) +@array_function_from_c_func_and_dispatcher(_multiarray_umath.busday_count) def busday_count(begindates, enddates, weekmask=None, holidays=None, busdaycal=None, out=None): """ @@ -1568,26 +1494,15 @@ def busday_count(begindates, enddates, weekmask=None, holidays=None, ... np.busday_count('2011', '2012', weekmask='Sat') 53 """ - kwargs = {} - if weekmask is not None: - kwargs['weekmask'] = weekmask - if holidays is not None: - kwargs['holidays'] = holidays - if busdaycal is not None: - kwargs['busdaycal'] = busdaycal - if out is not None: - kwargs['out'] = out - return _multiarray_umath.busday_count(begindates, enddates, **kwargs) - - -def _datetime_as_string_dispatcher( - arr, unit=None, timezone=None, casting=None): - return (arr,) + return (begindates, enddates, weekmask, holidays, out) -@array_function_dispatch(_datetime_as_string_dispatcher) -def datetime_as_string(arr, unit=None, timezone='naive', casting='same_kind'): +@array_function_from_c_func_and_dispatcher( + _multiarray_umath.datetime_as_string) +def datetime_as_string(arr, unit=None, timezone=None, casting=None): """ + datetime_as_string(arr, unit=None, timezone='naive', casting='same_kind') + Convert an array of datetimes into an array of strings. Parameters @@ -1644,5 +1559,4 @@ def datetime_as_string(arr, unit=None, timezone='naive', casting='same_kind'): TypeError: Cannot create a datetime string as units 'h' from a NumPy datetime with units 'm' according to the rule 'safe' """ - return _multiarray_umath.datetime_as_string(arr, unit, timezone, casting) - + return (arr,) |