Diffstat (limited to 'numpy')
24 files changed, 560 insertions, 49 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index f10ce9f0f..2ce2fdb55 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -843,7 +843,7 @@ defdict = {
     Ufunc(1, 1, None,
           docstrings.get('numpy.core.umath.isnan'),
           'PyUFunc_IsFiniteTypeResolver',
-          TD(noobj, out='?'),
+          TD(noobj, simd=[('avx512_skx', 'fd')], out='?'),
           ),
 'isnat':
     Ufunc(1, 1, None,
@@ -855,19 +855,19 @@ defdict = {
     Ufunc(1, 1, None,
           docstrings.get('numpy.core.umath.isinf'),
           'PyUFunc_IsFiniteTypeResolver',
-          TD(noobj, out='?'),
+          TD(noobj, simd=[('avx512_skx', 'fd')], out='?'),
           ),
 'isfinite':
     Ufunc(1, 1, None,
           docstrings.get('numpy.core.umath.isfinite'),
           'PyUFunc_IsFiniteTypeResolver',
-          TD(noobj, out='?'),
+          TD(noobj, simd=[('avx512_skx', 'fd')], out='?'),
           ),
 'signbit':
     Ufunc(1, 1, None,
           docstrings.get('numpy.core.umath.signbit'),
           None,
-          TD(flts, out='?'),
+          TD(flts, simd=[('avx512_skx', 'fd')], out='?'),
           ),
 'copysign':
     Ufunc(2, 1, None,
@@ -898,10 +898,10 @@ defdict = {
           docstrings.get('numpy.core.umath.ldexp'),
           None,
           [TypeDescription('e', None, 'ei', 'e'),
-           TypeDescription('f', None, 'fi', 'f'),
+           TypeDescription('f', None, 'fi', 'f', simd=['avx512_skx']),
            TypeDescription('e', FuncNameSuffix('long'), 'el', 'e'),
            TypeDescription('f', FuncNameSuffix('long'), 'fl', 'f'),
-           TypeDescription('d', None, 'di', 'd'),
+           TypeDescription('d', None, 'di', 'd', simd=['avx512_skx']),
            TypeDescription('d', FuncNameSuffix('long'), 'dl', 'd'),
            TypeDescription('g', None, 'gi', 'g'),
            TypeDescription('g', FuncNameSuffix('long'), 'gl', 'g'),
@@ -912,8 +912,8 @@ defdict = {
           docstrings.get('numpy.core.umath.frexp'),
           None,
           [TypeDescription('e', None, 'e', 'ei'),
-           TypeDescription('f', None, 'f', 'fi'),
-           TypeDescription('d', None, 'd', 'di'),
+           TypeDescription('f', None, 'f', 'fi', simd=['avx512_skx']),
+           TypeDescription('d', None, 'd', 'di', simd=['avx512_skx']),
            TypeDescription('g', None, 'g', 'gi'),
            ],
           ),
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index 5dd62e64a..1b61899fa 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -1821,7 +1821,7 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size,
  * may change without warning!
  */
 /* TODO: Make this definition public in the API, as soon as its settled */
-NPY_NO_EXPORT PyTypeObject PyArrayDTypeMeta_Type;
+NPY_NO_EXPORT extern PyTypeObject PyArrayDTypeMeta_Type;
 
 /*
  * While NumPy DTypes would not need to be heap types the plan is to
diff --git a/numpy/core/include/numpy/npy_1_7_deprecated_api.h b/numpy/core/include/numpy/npy_1_7_deprecated_api.h
index 440458010..a4f90e019 100644
--- a/numpy/core/include/numpy/npy_1_7_deprecated_api.h
+++ b/numpy/core/include/numpy/npy_1_7_deprecated_api.h
@@ -13,11 +13,10 @@
 #define _WARN___LOC__ __FILE__ "(" _WARN___STR1__(__LINE__) ") : Warning Msg: "
 #pragma message(_WARN___LOC__"Using deprecated NumPy API, disable it with " \
                              "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION")
-#elif defined(__GNUC__)
+#else
 #warning "Using deprecated NumPy API, disable it with " \
          "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION"
 #endif
-/* TODO: How to do this warning message for other compilers? */
 
 #endif
 
 /*
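A note on the ndarraytypes.h hunk: the added `extern` is what makes that line a pure declaration. Without it, every translation unit that includes the header carries its own tentative definition of PyArrayDTypeMeta_Type, which can fail or duplicate state under linkers that reject multiple definitions. A minimal sketch of the pattern, with hypothetical names that are not part of this commit:

#include <Python.h>

/* mydtype.h: declaration only, safe to include from every .c file */
extern PyTypeObject MyDType_Type;

/* mydtype.c: exactly one translation unit owns the definition */
PyTypeObject MyDType_Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "mydtype",
    .tp_basicsize = sizeof(PyObject),
};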
diff --git a/numpy/core/include/numpy/npy_3kcompat.h b/numpy/core/include/numpy/npy_3kcompat.h
index efe196c84..798da6957 100644
--- a/numpy/core/include/numpy/npy_3kcompat.h
+++ b/numpy/core/include/numpy/npy_3kcompat.h
@@ -60,6 +60,14 @@ static NPY_INLINE int PyInt_Check(PyObject *op) {
         PySlice_GetIndicesEx((PySliceObject *)op, nop, start, end, step, slicelength)
 #endif
 
+#if PY_VERSION_HEX < 0x030900a4
+    /* Introduced in https://github.com/python/cpython/commit/d2ec81a8c99796b51fb8c49b77a7fe369863226f */
+    #define Py_SET_TYPE(obj, typ) (Py_TYPE(obj) = typ)
+    /* Introduced in https://github.com/python/cpython/commit/b10dc3e7a11fcdb97e285882eba6da92594f90f9 */
+    #define Py_SET_SIZE(obj, size) (Py_SIZE(obj) = size)
+#endif
+
+
 #define Npy_EnterRecursiveCall(x) Py_EnterRecursiveCall(x)
 
 /* Py_SETREF was added in 3.5.2, and only if Py_LIMITED_API is absent */
@@ -546,4 +554,5 @@ NpyCapsule_Check(PyObject *ptr)
 }
 #endif
 
+
 #endif /* _NPY_3KCOMPAT_H_ */
diff --git a/numpy/core/include/numpy/npy_common.h b/numpy/core/include/numpy/npy_common.h
index c2e755958..3cec0c6ff 100644
--- a/numpy/core/include/numpy/npy_common.h
+++ b/numpy/core/include/numpy/npy_common.h
@@ -64,6 +64,13 @@
 #define NPY_GCC_TARGET_AVX512F
 #endif
 
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX && defined HAVE_LINK_AVX512_SKX
+#define NPY_GCC_TARGET_AVX512_SKX __attribute__((target("avx512f,avx512dq,avx512vl,avx512bw,avx512cd")))
+#elif defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS
+#define NPY_GCC_TARGET_AVX512_SKX __attribute__((target("avx512f,avx512dq,avx512vl,avx512bw,avx512cd")))
+#else
+#define NPY_GCC_TARGET_AVX512_SKX
+#endif
 /*
  * mark an argument (starting from 1) that must not be NULL and is not checked
  *   DO NOT USE IF FUNCTION CHECKS FOR NULL!! the compiler will remove the check
diff --git a/numpy/core/setup_common.py b/numpy/core/setup_common.py
index 72b59f9ae..8c0149497 100644
--- a/numpy/core/setup_common.py
+++ b/numpy/core/setup_common.py
@@ -147,6 +147,10 @@ OPTIONAL_INTRINSICS = [("__builtin_isnan", '5.'),
                                      "stdio.h", "LINK_AVX2"),
                        ("__asm__ volatile", '"vpaddd %zmm1, %zmm2, %zmm3"',
                                      "stdio.h", "LINK_AVX512F"),
+                       ("__asm__ volatile", '"vfpclasspd $0x40, %zmm15, %k6\\n"\
+                                             "vmovdqu8 %xmm0, %xmm1\\n"\
+                                             "vpbroadcastmb2q %k0, %xmm0\\n"',
+                                     "stdio.h", "LINK_AVX512_SKX"),
                        ("__asm__ volatile", '"xgetbv"', "stdio.h", "XGETBV"),
                        ]
 
@@ -165,6 +169,8 @@ OPTIONAL_FUNCTION_ATTRIBUTES = [('__attribute__((optimize("unroll-loops")))',
                                  'attribute_target_avx2'),
                                 ('__attribute__((target ("avx512f")))',
                                  'attribute_target_avx512f'),
+                                ('__attribute__((target ("avx512f,avx512dq,avx512bw,avx512vl,avx512cd")))',
+                                 'attribute_target_avx512_skx'),
                                 ]
 
 # function attributes with intrinsics
@@ -181,6 +187,11 @@ OPTIONAL_FUNCTION_ATTRIBUTES_WITH_INTRINSICS = [('__attribute__((target("avx2,fm
                                 'attribute_target_avx512f_with_intrinsics',
                                 '__m512 temp = _mm512_set1_ps(1.0)',
                                 'immintrin.h'),
+                                ('__attribute__((target ("avx512f,avx512dq,avx512bw,avx512vl,avx512cd")))',
+                                 'attribute_target_avx512_skx_with_intrinsics',
+                                 '__mmask8 temp = _mm512_fpclass_pd_mask(_mm512_set1_pd(1.0), 0x01);\
+                                  _mm_mask_storeu_epi8(NULL, 0xFF, _mm_broadcastmb_epi64(temp))',
+                                 'immintrin.h'),
                                 ]
 
 # variable attributes tested via "int %s a" % attribute
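The Py_SET_TYPE/Py_SET_SIZE shims above exist because CPython 3.9 turned Py_TYPE() and Py_SIZE() into inline functions returning rvalues, so the old `Py_TYPE(obj) = type` assignment style no longer compiles there. With the backport in place, call sites can be written one way for all supported Pythons. A self-contained sketch (the helper name is hypothetical):

#include <Python.h>

#if PY_VERSION_HEX < 0x030900a4
    #define Py_SET_TYPE(obj, typ) (Py_TYPE(obj) = typ)
    #define Py_SET_SIZE(obj, size) (Py_SIZE(obj) = size)
#endif

/* retag an object the way later hunks retag PyArrayDescr_Type and
 * PyFortran_Type; compiles on CPython both before and after 3.9.0a4 */
static void
retag(PyObject *obj, PyTypeObject *new_type)
{
    Py_SET_TYPE(obj, new_type);    /* was: Py_TYPE(obj) = new_type; */
}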
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src
index 318559885..da631c830 100644
--- a/numpy/core/src/multiarray/_multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/_multiarray_tests.c.src
@@ -2045,6 +2045,21 @@ run_casting_converter(PyObject* NPY_UNUSED(self), PyObject *args)
     return PyInt_FromLong(casting);
 }
 
+static PyObject *
+run_intp_converter(PyObject* NPY_UNUSED(self), PyObject *args)
+{
+    PyArray_Dims dims = {NULL, -1};
+    if (!PyArg_ParseTuple(args, "O&", PyArray_IntpConverter, &dims)) {
+        return NULL;
+    }
+    if (dims.len == -1) {
+        Py_RETURN_NONE;
+    }
+
+    PyObject *tup = PyArray_IntTupleFromIntp(dims.len, dims.ptr);
+    PyDimMem_FREE(dims.ptr);
+    return tup;
+}
 
 static PyMethodDef Multiarray_TestsMethods[] = {
     {"IsPythonScalar",
@@ -2218,6 +2233,9 @@ static PyMethodDef Multiarray_TestsMethods[] = {
     {"run_casting_converter",
         run_casting_converter,
         METH_VARARGS, NULL},
+    {"run_intp_converter",
+        run_intp_converter,
+        METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL}        /* Sentinel */
 };
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 552c56349..2048b5898 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -27,6 +27,9 @@
 #include "arrayobject.h"
 #include "alloc.h"
 #include "typeinfo.h"
+#if defined(__ARM_NEON__) || defined (__ARM_NEON)
+#include <arm_neon.h>
+#endif
 #ifdef NPY_HAVE_SSE2_INTRINSICS
 #include <emmintrin.h>
 #endif
@@ -3070,7 +3073,15 @@ finish:
 **                                 ARGFUNC
 *****************************************************************************
 */
-
+#if defined(__ARM_NEON__) || defined (__ARM_NEON)
+    int32_t _mm_movemask_epi8_neon(uint8x16_t input)
+    {
+        int8x8_t m0 = vcreate_s8(0x0706050403020100ULL);
+        uint8x16_t v0 = vshlq_u8(vshrq_n_u8(input, 7), vcombine_s8(m0, m0));
+        uint64x2_t v1 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v0)));
+        return (int)vgetq_lane_u64(v1, 0) + ((int)vgetq_lane_u64(v1, 1) << 8);
+    }
+#endif
 #define _LESS_THAN_OR_EQUAL(a,b) ((a) <= (b))
 
 static int
@@ -3091,6 +3102,19 @@ BOOL_argmax(npy_bool *ip, npy_intp n, npy_intp *max_ind,
             break;
         }
     }
+#else
+    #if defined(__ARM_NEON__) || defined (__ARM_NEON)
+        uint8x16_t zero = vdupq_n_u8(0);
+        for(; i < n - (n % 32); i+=32) {
+            uint8x16_t d1 = vld1q_u8((char *)&ip[i]);
+            uint8x16_t d2 = vld1q_u8((char *)&ip[i + 16]);
+            d1 = vceqq_u8(d1, zero);
+            d2 = vceqq_u8(d2, zero);
+            if(_mm_movemask_epi8_neon(vminq_u8(d1, d2)) != 0xFFFF) {
+                break;
+            }
+        }
+    #endif
 #endif
     for (; i < n; i++) {
         if (ip[i]) {
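For reference, _mm_movemask_epi8_neon emulates SSE2's _mm_movemask_epi8: the most significant bit of each of the 16 input bytes becomes one bit of the 16-bit result. A standalone version with a small driver (test values are hypothetical; requires an ARM toolchain with NEON):

#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

static int32_t movemask_u8(uint8x16_t input)
{
    int8x8_t shifts = vcreate_s8(0x0706050403020100ULL);
    /* move each byte's MSB to bit 0, then shift it to its lane position */
    uint8x16_t v0 = vshlq_u8(vshrq_n_u8(input, 7), vcombine_s8(shifts, shifts));
    /* pairwise long adds collapse the per-byte bits into two 8-bit masks */
    uint64x2_t v1 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v0)));
    return (int)vgetq_lane_u64(v1, 0) + ((int)vgetq_lane_u64(v1, 1) << 8);
}

int main(void)
{
    uint8_t bytes[16] = {0};
    bytes[0] = 0x80;    /* MSB set -> bit 0 of the mask */
    bytes[15] = 0xFF;   /* MSB set -> bit 15 of the mask */
    printf("0x%04x\n", (unsigned)movemask_u8(vld1q_u8(bytes)));  /* 0x8001 */
    return 0;
}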
diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c
index 14d546867..ac698d680 100644
--- a/numpy/core/src/multiarray/conversion_utils.c
+++ b/numpy/core/src/multiarray/conversion_utils.c
@@ -95,9 +95,21 @@ PyArray_IntpConverter(PyObject *obj, PyArray_Dims *seq)
 
     seq->ptr = NULL;
     seq->len = 0;
+
+    /*
+     * When the deprecation below expires, remove the `if` statement, and
+     * update the comment for PyArray_OptionalIntpConverter.
+     */
     if (obj == Py_None) {
+        /* Numpy 1.20, 2020-05-31 */
+        if (DEPRECATE(
+                "Passing None into shape arguments as an alias for () is "
+                "deprecated.") < 0){
+            return NPY_FAIL;
+        }
         return NPY_SUCCEED;
     }
+
     len = PySequence_Size(obj);
     if (len == -1) {
         /* Check to see if it is an integer number */
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index 76f7b599a..9982cd676 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -230,7 +230,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
     }
 
     /* Finally, replace the current class of the descr */
-    Py_TYPE(descr) = (PyTypeObject *)dtype_class;
+    Py_SET_TYPE(descr, (PyTypeObject *)dtype_class);
 
     return 0;
 }
@@ -266,4 +266,3 @@ NPY_NO_EXPORT PyTypeObject PyArrayDTypeMeta_Type = {
     .tp_is_gc = dtypemeta_is_gc,
     .tp_traverse = (traverseproc)dtypemeta_traverse,
 };
-
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index ab5076711..84c22ba65 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2498,9 +2498,9 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
                             "subscript is not within the valid range [0, 52)");
                     Py_DECREF(obj);
                     return -1;
-                }
-            }
-
+                }
+            }
+
         }
 
         Py_DECREF(obj);
@@ -4453,7 +4453,7 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
     }
 
     PyArrayDescr_Type.tp_hash = PyArray_DescrHash;
-    Py_TYPE(&PyArrayDescr_Type) = &PyArrayDTypeMeta_Type;
+    Py_SET_TYPE(&PyArrayDescr_Type, &PyArrayDTypeMeta_Type);
     if (PyType_Ready(&PyArrayDescr_Type) < 0) {
         goto err;
     }
diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c
index 8a7139fb2..f3c440dc6 100644
--- a/numpy/core/src/multiarray/scalarapi.c
+++ b/numpy/core/src/multiarray/scalarapi.c
@@ -755,7 +755,7 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base)
         vobj->descr = descr;
         Py_INCREF(descr);
         vobj->obval = NULL;
-        Py_SIZE(vobj) = itemsize;
+        Py_SET_SIZE(vobj, itemsize);
         vobj->flags = NPY_ARRAY_CARRAY | NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_OWNDATA;
         swap = 0;
         if (PyDataType_HASFIELDS(descr)) {
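Because the DeprecationWarning is raised inside PyArray_IntpConverter itself, every shape-accepting entry point that routes through the converter inherits it, and the converter can now fail for None. For reference, this is the calling convention the new run_intp_converter helper exercises (a sketch; module table and import_array() boilerplate omitted):

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>

static PyObject *
shape_roundtrip(PyObject *NPY_UNUSED(self), PyObject *args)
{
    PyArray_Dims dims = {NULL, -1};
    /* "O&" hands the argument to PyArray_IntpConverter, which may now
     * fail if the DeprecationWarning for None is turned into an error */
    if (!PyArg_ParseTuple(args, "O&", PyArray_IntpConverter, &dims)) {
        return NULL;
    }
    if (dims.len == -1) {          /* None: no dimensions were produced */
        Py_RETURN_NONE;
    }
    PyObject *tup = PyArray_IntTupleFromIntp(dims.len, dims.ptr);
    PyDimMem_FREE(dims.ptr);       /* the converter allocated the buffer */
    return tup;
}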
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index f13f50759..a7c3e847a 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -345,7 +345,7 @@ format_@name@(@type@ val, npy_bool scientific,
  * over-ride repr and str of array-scalar strings and unicode to
  * remove NULL bytes and then call the corresponding functions
  * of string and unicode.
- * 
+ *
  * FIXME:
  *   is this really a good idea?
  *   stop using Py_UNICODE here.
@@ -1542,7 +1542,7 @@ static PyObject *
         return NULL;
     }
 #endif
-    
+
     PyObject *tup;
     if (ndigits == Py_None) {
         tup = PyTuple_Pack(0);
@@ -1568,7 +1568,7 @@ static PyObject *
         return ret;
     }
 #endif
-    
+
     return obj;
 }
 /**end repeat**/
@@ -2774,7 +2774,7 @@ void_arrtype_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
         return PyErr_NoMemory();
     }
     ((PyVoidScalarObject *)ret)->obval = destptr;
-    Py_SIZE((PyVoidScalarObject *)ret) = (int) memu;
+    Py_SET_SIZE((PyVoidScalarObject *)ret, (int) memu);
     ((PyVoidScalarObject *)ret)->descr = PyArray_DescrNewFromType(NPY_VOID);
     ((PyVoidScalarObject *)ret)->descr->elsize = (int) memu;
diff --git a/numpy/core/src/umath/_rational_tests.c.src b/numpy/core/src/umath/_rational_tests.c.src
index 651019a84..13e33d0a5 100644
--- a/numpy/core/src/umath/_rational_tests.c.src
+++ b/numpy/core/src/umath/_rational_tests.c.src
@@ -1158,7 +1158,7 @@ PyMODINIT_FUNC PyInit__rational_tests(void) {
     npyrational_arrfuncs.fill = npyrational_fill;
     npyrational_arrfuncs.fillwithscalar = npyrational_fillwithscalar;
     /* Left undefined: scanfunc, fromstr, sort, argsort */
-    Py_TYPE(&npyrational_descr) = &PyArrayDescr_Type;
+    Py_SET_TYPE(&npyrational_descr, &PyArrayDescr_Type);
     npy_rational = PyArray_RegisterDataType(&npyrational_descr);
     if (npy_rational<0) {
         goto fail;
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index a59a9acf5..0cfa1cea7 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1863,10 +1863,15 @@ NPY_NO_EXPORT void
  * #kind = isnan, isinf, isfinite, signbit#
  * #func = npy_isnan, npy_isinf, npy_isfinite, npy_signbit#
  **/
+
+/**begin repeat2
+ * #ISA = , _avx512_skx#
+ * #isa = simd, avx512_skx#
+ **/
 NPY_NO_EXPORT void
-@TYPE@_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
+@TYPE@_@kind@@ISA@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
 {
-    if (!run_@kind@_simd_@TYPE@(args, dimensions, steps)) {
+    if (!run_@kind@_@isa@_@TYPE@(args, dimensions, steps)) {
         UNARY_LOOP {
             const @type@ in1 = *(@type@ *)ip1;
             *((npy_bool *)op1) = @func@(in1) != 0;
@@ -1874,6 +1879,7 @@ NPY_NO_EXPORT void
     }
     npy_clear_floatstatus_barrier((char*)dimensions);
 }
+/**end repeat2**/
 /**end repeat1**/
 
 NPY_NO_EXPORT void
@@ -2131,6 +2137,14 @@ NPY_NO_EXPORT void
 }
 
 NPY_NO_EXPORT void
+@TYPE@_frexp_avx512_skx(char **args, npy_intp const *dimensions, npy_intp const *steps, void *func)
+{
+    if (!run_unary_two_out_avx512_skx_frexp_@TYPE@(args, dimensions, steps)) {
+        @TYPE@_frexp(args, dimensions, steps, func);
+    }
+}
+
+NPY_NO_EXPORT void
 @TYPE@_ldexp(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
 {
     BINARY_LOOP {
@@ -2141,6 +2155,14 @@ NPY_NO_EXPORT void
 }
 
 NPY_NO_EXPORT void
+@TYPE@_ldexp_avx512_skx(char **args, const npy_intp *dimensions, const npy_intp *steps, void *func)
+{
+    if (!run_binary_avx512_skx_ldexp_@TYPE@(args, dimensions, steps)) {
+        @TYPE@_ldexp(args, dimensions, steps, func);
+    }
+}
+
+NPY_NO_EXPORT void
 @TYPE@_ldexp_long(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
 {
     /*
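The pattern in loops.c.src above is worth spelling out: each new *_avx512_skx entry point calls a run_* gate that returns 1 only if the SIMD kernel actually handled the data, and otherwise falls through to the portable loop. The same contract reduced to a hypothetical scalar toy:

#include <stddef.h>
#include <math.h>

/* returns 1 if the vector path ran, 0 if the caller must do scalar work;
 * here it always declines, standing in for the compile-time and stride
 * checks that the real run_* functions perform */
static int
run_isnan_simd(const double *in, unsigned char *out, size_t n)
{
    (void)in; (void)out; (void)n;
    return 0;
}

void
double_isnan(const double *in, unsigned char *out, size_t n)
{
    if (!run_isnan_simd(in, out, n)) {
        for (size_t i = 0; i < n; i++) {
            out[i] = isnan(in[i]) != 0;    /* scalar fallback */
        }
    }
}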
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index 50a7ccfee..5dd49c465 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -274,8 +274,13 @@ NPY_NO_EXPORT void
  * #kind = isnan, isinf, isfinite, signbit, copysign, nextafter, spacing#
  * #func = npy_isnan, npy_isinf, npy_isfinite, npy_signbit, npy_copysign, nextafter, spacing#
  **/
+
+/**begin repeat2
+ * #ISA = , _avx512_skx#
+ **/
 NPY_NO_EXPORT void
-@TYPE@_@kind@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+@TYPE@_@kind@@ISA@(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+/**end repeat2**/
 /**end repeat1**/
 
 /**begin repeat1
@@ -334,9 +339,15 @@ NPY_NO_EXPORT void
 @TYPE@_frexp(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
 
 NPY_NO_EXPORT void
+@TYPE@_frexp_avx512_skx(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+NPY_NO_EXPORT void
 @TYPE@_ldexp(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
 
 NPY_NO_EXPORT void
+@TYPE@_ldexp_avx512_skx(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
+
+NPY_NO_EXPORT void
 @TYPE@_ldexp_long(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func));
 
 #define @TYPE@_true_divide @TYPE@_divide
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index 6b0bcc3dc..48e89915c 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -1,4 +1,4 @@
-/* -*- c -*- */
+
 /*
  * This file is for the definitions of simd vectorized operations.
@@ -120,6 +120,13 @@ nomemoverlap(char *ip,
     (nomemoverlap(args[0], steps[0] * dimensions[0], args[2], steps[2] * dimensions[0])) && \
     (nomemoverlap(args[1], steps[1] * dimensions[0], args[2], steps[2] * dimensions[0])))
 
+#define IS_UNARY_TWO_OUT_SMALL_STEPS_AND_NOMEMOVERLAP \
+    ((abs(steps[0]) < MAX_STEP_SIZE) && \
+     (abs(steps[1]) < MAX_STEP_SIZE) && \
+     (abs(steps[2]) < MAX_STEP_SIZE) && \
+     (nomemoverlap(args[0], steps[0] * dimensions[0], args[2], steps[2] * dimensions[0])) && \
+     (nomemoverlap(args[0], steps[0] * dimensions[0], args[1], steps[1] * dimensions[0])))
+
 /*
  * 1) Output should be contiguous, can handle strided input data
  * 2) Input step should be smaller than MAX_STEP_SIZE for performance
@@ -294,6 +301,76 @@ run_binary_avx512f_@func@_@TYPE@(char **args, npy_intp const *dimensions, npy_in
 
 /**end repeat1**/
 
+
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS && @EXISTS@
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_ldexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps);
+
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_frexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps);
+#endif
+
+static NPY_INLINE int
+run_binary_avx512_skx_ldexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
+{
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS && @EXISTS@
+    if (IS_BINARY_SMALL_STEPS_AND_NOMEMOVERLAP) {
+        AVX512_SKX_ldexp_@TYPE@(args, dimensions, steps);
+        return 1;
+    }
+    else
+        return 0;
+#endif
+    return 0;
+}
+
+static NPY_INLINE int
+run_unary_two_out_avx512_skx_frexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
+{
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS && @EXISTS@
+    if (IS_UNARY_TWO_OUT_SMALL_STEPS_AND_NOMEMOVERLAP) {
+        AVX512_SKX_frexp_@TYPE@(args, dimensions, steps);
+        return 1;
+    }
+    else
+        return 0;
+#endif
+    return 0;
+}
+/**end repeat**/
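IS_UNARY_TWO_OUT_SMALL_STEPS_AND_NOMEMOVERLAP gates the frexp kernel on two preconditions: all strides are small enough that the 32-bit gather/scatter indices cannot overflow, and neither output range aliases the input. A rough standalone sketch of the same idea (illustrative only; NumPy's nomemoverlap is the authoritative version):

#include <stddef.h>

/* byte ranges [a, a+an) and [b, b+bn) do not intersect */
static int
no_overlap(const char *a, size_t an, const char *b, size_t bn)
{
    return (a + an <= b) || (b + bn <= a);
}

/* one input, two outputs, as for frexp */
static int
can_use_simd(char *in, char *out1, char *out2, size_t nbytes,
             size_t in_step, size_t out1_step, size_t out2_step,
             size_t max_step)
{
    return in_step < max_step && out1_step < max_step && out2_step < max_step
        && no_overlap(in, nbytes, out1, nbytes)
        && no_overlap(in, nbytes, out2, nbytes);
}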
+
+/**begin repeat
+ * #type = npy_float, npy_double, npy_longdouble#
+ * #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
+ * #EXISTS = 1, 1, 0#
+ */
+
+/**begin repeat1
+ * #func = isnan, isfinite, isinf, signbit#
+ */
+
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS && @EXISTS@
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_@func@_@TYPE@(npy_bool*, @type@*, const npy_intp n, const npy_intp stride);
+#endif
+
+static NPY_INLINE int
+run_@func@_avx512_skx_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
+{
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS && @EXISTS@
+    if (IS_OUTPUT_BLOCKABLE_UNARY(sizeof(npy_bool), 64)) {
+        AVX512_SKX_@func@_@TYPE@((npy_bool*)args[1], (@type@*)args[0], dimensions[0], steps[0]);
+        return 1;
+    }
+    else {
+        return 0;
+    }
+#endif
+    return 0;
+}
+
+
+/**end repeat1**/
 /**end repeat**/
 
 /**begin repeat
  * #type = npy_float, npy_double#
  * #TYPE = FLOAT, DOUBLE#
  * #num_lanes = 16, 8#
  * #vsuffix = ps, pd#
  * #mask = __mmask16, __mmask8#
  * #vtype = __m512, __m512d#
  * #scale = 4, 8#
  * #vindextype = __m512i, __m256i#
@@ -1980,10 +2057,242 @@ static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ @vtype@d
  * #vtype = __m512, __m512d#
  * #scale = 4, 8#
  * #vindextype = __m512i, __m256i#
+ * #vindexload = _mm512_loadu_si512, _mm256_loadu_si256#
+ * #episize = epi32, epi64#
+ */
+
+/**begin repeat1
+ * #func = isnan, isfinite, isinf, signbit#
+ * #IMM8 = 0x81, 0x99, 0x18, 0x04#
+ * #is_finite = 0, 1, 0, 0#
+ * #is_signbit = 0, 0, 0, 1#
+ */
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_@func@_@TYPE@(npy_bool* op, @type@* ip, const npy_intp array_size, const npy_intp steps)
+{
+    const npy_intp stride_ip = steps/(npy_intp)sizeof(@type@);
+    npy_intp num_remaining_elements = array_size;
+
+    @mask@ load_mask = avx512_get_full_load_mask_@vsuffix@();
+#if @is_signbit@
+    @vtype@ signbit = _mm512_set1_@vsuffix@(-0.0);
+#endif
+
+    /*
+     * Note: while generally indices are npy_intp, we ensure that our maximum
+     * index will fit in an int32 as a precondition for this function via
+     * IS_OUTPUT_BLOCKABLE_UNARY
+     */
+
+    npy_int32 index_ip[@num_lanes@];
+    for (npy_int32 ii = 0; ii < @num_lanes@; ii++) {
+        index_ip[ii] = ii*stride_ip;
+    }
+    @vindextype@ vindex_ip = @vindexload@((@vindextype@*)&index_ip[0]);
+    @vtype@ zeros_f = _mm512_setzero_@vsuffix@();
+    __m512i ones = _mm512_set1_@episize@(1);
+
+    while (num_remaining_elements > 0) {
+        if (num_remaining_elements < @num_lanes@) {
+            load_mask = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements, @num_lanes@);
+        }
+        @vtype@ x1;
+        if (stride_ip == 1) {
+            x1 = avx512_masked_load_@vsuffix@(load_mask, ip);
+        }
+        else {
+            x1 = avx512_masked_gather_@vsuffix@(zeros_f, ip, vindex_ip, load_mask);
+        }
+#if @is_signbit@
+        x1 = _mm512_and_@vsuffix@(x1,signbit);
+#endif
+
+        @mask@ fpclassmask = _mm512_fpclass_@vsuffix@_mask(x1, @IMM8@);
+#if @is_finite@
+        fpclassmask = _mm512_knot(fpclassmask);
+#endif
+
+        __m128i out =_mm512_maskz_cvts@episize@_epi8(fpclassmask, ones);
+        _mm_mask_storeu_epi8(op, load_mask, out);
+
+        ip += @num_lanes@*stride_ip;
+        op += @num_lanes@;
+        num_remaining_elements -= @num_lanes@;
+    }
+}
+#endif
+/**end repeat1**/
+/**end repeat**/
+
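The @IMM8@ constants are vfpclass category masks: 0x01 QNaN and 0x80 SNaN (so 0x81 means isnan), 0x08 positive infinity and 0x10 negative infinity (0x18 means isinf), and 0x04 negative zero (used by signbit after AND-ing the input with -0.0); isfinite is the complement of 0x99 (NaN or infinity). A standalone check of those encodings (requires an AVX512DQ machine; compile with e.g. gcc -mavx512dq):

#include <immintrin.h>
#include <math.h>
#include <stdio.h>

int main(void)
{
    /* lane 7 down to lane 0 */
    __m512d x = _mm512_set_pd(NAN, INFINITY, -INFINITY, -0.0,
                              0.0, 1.0, -1.0, 2.5);
    __mmask8 is_nan = _mm512_fpclass_pd_mask(x, 0x81);
    __mmask8 is_inf = _mm512_fpclass_pd_mask(x, 0x18);
    __mmask8 is_finite = _mm512_knot(_mm512_fpclass_pd_mask(x, 0x99));
    /* prints nan=80 inf=60 finite=1f */
    printf("nan=%02x inf=%02x finite=%02x\n",
           (unsigned)is_nan, (unsigned)is_inf, (unsigned)is_finite);
    return 0;
}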
+/**begin repeat
+ * #type = npy_float, npy_double#
+ * #TYPE = FLOAT, DOUBLE#
+ * #num_lanes = 16, 8#
+ * #vsuffix = ps, pd#
+ * #mask = __mmask16, __mmask8#
+ * #vtype1 = __m512, __m512d#
+ * #vtype2 = __m512i, __m256i#
+ * #scale = 4, 8#
+ * #vindextype = __m512i, __m256i#
+ * #vindexsize = 512, 256#
+ * #vindexload = _mm512_loadu_si512, _mm256_loadu_si256#
+ * #vtype2_load = _mm512_maskz_loadu_epi32, _mm256_maskz_loadu_epi32#
+ * #vtype2_gather = _mm512_mask_i32gather_epi32, _mm256_mmask_i32gather_epi32#
+ * #vtype2_store = _mm512_mask_storeu_epi32, _mm256_mask_storeu_epi32#
+ * #vtype2_scatter = _mm512_mask_i32scatter_epi32, _mm256_mask_i32scatter_epi32#
+ * #setzero = _mm512_setzero_epi32, _mm256_setzero_si256#
+ */
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512_SKX_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_ldexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
+{
+    const npy_intp stride_ip1 = steps[0]/(npy_intp)sizeof(@type@);
+    const npy_intp stride_ip2 = steps[1]/(npy_intp)sizeof(int);
+    const npy_intp stride_op = steps[2]/(npy_intp)sizeof(@type@);
+    const npy_intp array_size = dimensions[0];
+    npy_intp num_remaining_elements = array_size;
+    @type@* ip1 = (@type@*) args[0];
+    int* ip2 = (int*) args[1];
+    @type@* op = (@type@*) args[2];
+
+    @mask@ load_mask = avx512_get_full_load_mask_@vsuffix@();
+
+    /*
+     * Note: while generally indices are npy_intp, we ensure that our maximum index
+     * will fit in an int32 as a precondition for this function via
+     * IS_BINARY_SMALL_STEPS_AND_NOMEMOVERLAP
+     */
+
+    npy_int32 index_ip1[@num_lanes@], index_ip2[@num_lanes@], index_op[@num_lanes@];
+    for (npy_int32 ii = 0; ii < @num_lanes@; ii++) {
+        index_ip1[ii] = ii*stride_ip1;
+        index_ip2[ii] = ii*stride_ip2;
+        index_op[ii] = ii*stride_op;
+    }
+    @vindextype@ vindex_ip1 = @vindexload@((@vindextype@*)&index_ip1[0]);
+    @vindextype@ vindex_ip2 = @vindexload@((@vindextype@*)&index_ip2[0]);
+    @vindextype@ vindex_op = @vindexload@((@vindextype@*)&index_op[0]);
+    @vtype1@ zeros_f = _mm512_setzero_@vsuffix@();
+    @vtype2@ zeros = @setzero@();
+
+    while (num_remaining_elements > 0) {
+        if (num_remaining_elements < @num_lanes@) {
+            load_mask = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements, @num_lanes@);
+        }
+        @vtype1@ x1;
+        @vtype2@ x2;
+        if (stride_ip1 == 1) {
+            x1 = avx512_masked_load_@vsuffix@(load_mask, ip1);
+        }
+        else {
+            x1 = avx512_masked_gather_@vsuffix@(zeros_f, ip1, vindex_ip1, load_mask);
+        }
+        if (stride_ip2 == 1) {
+            x2 = @vtype2_load@(load_mask, ip2);
+        }
+        else {
+            x2 = @vtype2_gather@(zeros, load_mask, vindex_ip2, ip2, 4);
+        }
+
+        @vtype1@ out = _mm512_scalef_@vsuffix@(x1, _mm512_cvtepi32_@vsuffix@(x2));
+
+        if (stride_op == 1) {
+            _mm512_mask_storeu_@vsuffix@(op, load_mask, out);
+        }
+        else {
+            /* scatter! */
+            _mm512_mask_i32scatter_@vsuffix@(op, load_mask, vindex_op, out, @scale@);
+        }
+
+        ip1 += @num_lanes@*stride_ip1;
+        ip2 += @num_lanes@*stride_ip2;
+        op += @num_lanes@*stride_op;
+        num_remaining_elements -= @num_lanes@;
+    }
+}
+
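The whole ldexp kernel leans on one instruction: vscalef computes x * 2^floor(y) per lane, which is exactly ldexp once the integer exponents are converted to floating-point lanes. A minimal check against the libm scalar (assumes AVX512F hardware, gcc -mavx512f):

#include <immintrin.h>
#include <math.h>
#include <stdio.h>

int main(void)
{
    __m512d x = _mm512_set1_pd(0.375);
    __m512d e = _mm512_set1_pd(4.0);      /* exponent as double lanes */
    __m512d r = _mm512_scalef_pd(x, e);   /* 0.375 * 2^4 in every lane */
    double out[8];
    _mm512_storeu_pd(out, r);
    printf("%g %g\n", out[0], ldexp(0.375, 4));   /* both print 6 */
    return 0;
}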
+static NPY_INLINE NPY_GCC_TARGET_AVX512_SKX void
+AVX512_SKX_frexp_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
+{
+    const npy_intp stride_ip1 = steps[0]/(npy_intp)sizeof(@type@);
+    const npy_intp stride_op1 = steps[1]/(npy_intp)sizeof(@type@);
+    const npy_intp stride_op2 = steps[2]/(npy_intp)sizeof(int);
+    const npy_intp array_size = dimensions[0];
+    npy_intp num_remaining_elements = array_size;
+    @type@* ip1 = (@type@*) args[0];
+    @type@* op1 = (@type@*) args[1];
+    int* op2 = (int*) args[2];
+
+    @mask@ load_mask = avx512_get_full_load_mask_@vsuffix@();
+
+    /*
+     * Note: while generally indices are npy_intp, we ensure that our maximum index
+     * will fit in an int32 as a precondition for this function via
+     * IS_BINARY_SMALL_STEPS_AND_NOMEMOVERLAP
+     */
+
+    npy_int32 index_ip1[@num_lanes@], index_op1[@num_lanes@], index_op2[@num_lanes@];
+    for (npy_int32 ii = 0; ii < @num_lanes@; ii++) {
+        index_ip1[ii] = ii*stride_ip1;
+        index_op1[ii] = ii*stride_op1;
+        index_op2[ii] = ii*stride_op2;
+    }
+    @vindextype@ vindex_ip1 = @vindexload@((@vindextype@*)&index_ip1[0]);
+    @vindextype@ vindex_op1 = @vindexload@((@vindextype@*)&index_op1[0]);
+    @vindextype@ vindex_op2 = @vindexload@((@vindextype@*)&index_op2[0]);
+    @vtype1@ zeros_f = _mm512_setzero_@vsuffix@();
+
+    while (num_remaining_elements > 0) {
+        if (num_remaining_elements < @num_lanes@) {
+            load_mask = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements, @num_lanes@);
+        }
+        @vtype1@ x1;
+        if (stride_ip1 == 1) {
+            x1 = avx512_masked_load_@vsuffix@(load_mask, ip1);
+        }
+        else {
+            x1 = avx512_masked_gather_@vsuffix@(zeros_f, ip1, vindex_ip1, load_mask);
+        }
+
+        /*
+         * The x86 instructions vpgetmant and vpgetexp do not conform
+         * with NumPy's output for special floating points: NAN, +/-INF, +/-0.0
+         * We mask these values with spmask to avoid invalid exceptions.
+         */
+        @mask@ spmask =_mm512_knot(_mm512_fpclass_@vsuffix@_mask(
+                x1, 0b10011111));
+        @vtype1@ out1 = _mm512_maskz_getmant_@vsuffix@(
+                spmask, x1, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
+        out1 = _mm512_mask_mov_@vsuffix@(x1, spmask, out1);
+        @vtype2@ out2 = _mm512_cvt@vsuffix@_epi32(
+                _mm512_maskz_add_@vsuffix@(spmask, _mm512_set1_@vsuffix@(1.0),
+                    _mm512_maskz_getexp_@vsuffix@(spmask, x1)));
+        if (stride_op1 == 1) {
+            _mm512_mask_storeu_@vsuffix@(op1, load_mask, out1);
+        }
+        else {
+            _mm512_mask_i32scatter_@vsuffix@(op1, load_mask, vindex_op1, out1, @scale@);
+        }
+        if (stride_op2 == 1) {
+            @vtype2_store@(op2, load_mask, out2);
+        }
+        else {
+            @vtype2_scatter@(op2, load_mask, vindex_op2, out2, 4);
+        }
+
+        ip1 += @num_lanes@*stride_ip1;
+        op1 += @num_lanes@*stride_op1;
+        op2 += @num_lanes@*stride_op2;
+        num_remaining_elements -= @num_lanes@;
+    }
+}
+#endif
+
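The frexp kernel splits x into getmant(...), which lands in [0.5, 1) with x's sign, and getexp(x) + 1, matching frexp's m * 2^e convention for normal inputs; spmask is what keeps NaN, infinities and zeros away from those instructions. A scalar-versus-vector sanity check for one normal value (assumes AVX512F hardware, gcc -mavx512f):

#include <immintrin.h>
#include <math.h>
#include <stdio.h>

int main(void)
{
    __m512d x = _mm512_set1_pd(6.0);
    __m512d m = _mm512_getmant_pd(x, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
    __m512d e = _mm512_getexp_pd(x);    /* floor(log2(|x|)) = 2 */
    double mant[8], expo[8];
    _mm512_storeu_pd(mant, m);
    _mm512_storeu_pd(expo, e);

    int e_ref;
    double m_ref = frexp(6.0, &e_ref);  /* 0.75 and 3 */
    /* vector mantissa 0.75, vector exponent 2 + 1 = 3, matching libm */
    printf("%g %g  vs  %g %d\n", mant[0], expo[0] + 1.0, m_ref, e_ref);
    return 0;
}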
 
 /**begin repeat1
  * #func = maximum, minimum#
  * #vectorf = max, min#
@@ -2019,14 +2328,14 @@ AVX512F_@func@_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
     @vindextype@ vindex_ip1 = @vindexload@((@vindextype@*)&index_ip1[0]);
     @vindextype@ vindex_ip2 = @vindexload@((@vindextype@*)&index_ip2[0]);
     @vindextype@ vindex_op = @vindexload@((@vindextype@*)&index_op[0]);
-    @vtype@ zeros_f = _mm512_setzero_@vsuffix@();
+    @vtype1@ zeros_f = _mm512_setzero_@vsuffix@();
 
     while (num_remaining_elements > 0) {
         if (num_remaining_elements < @num_lanes@) {
             load_mask = avx512_get_partial_load_mask_@vsuffix@(
                     num_remaining_elements, @num_lanes@);
         }
-        @vtype@ x1, x2;
+        @vtype1@ x1, x2;
         if (stride_ip1 == 1) {
             x1 = avx512_masked_load_@vsuffix@(load_mask, ip1);
         }
@@ -2046,7 +2355,7 @@ AVX512F_@func@_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
      * this issue to conform with NumPy behaviour.
      */
     @mask@ nan_mask = _mm512_cmp_@vsuffix@_mask(x1, x1, _CMP_NEQ_UQ);
-    @vtype@ out = _mm512_@vectorf@_@vsuffix@(x1, x2);
+    @vtype1@ out = _mm512_@vectorf@_@vsuffix@(x1, x2);
     out = _mm512_mask_blend_@vsuffix@(nan_mask, out, x1);
 
     if (stride_op == 1) {
@@ -2064,8 +2373,8 @@ AVX512F_@func@_@TYPE@(char **args, npy_intp const *dimensions, npy_intp const *steps)
     }
 }
 #endif
-/**end repeat**/
 /**end repeat1**/
+/**end repeat**/
 
 /**begin repeat
  * #ISA = FMA, AVX512F#
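The nan_mask/blend pair above exists because vmaxpd/vminpd return the second operand when an input is NaN, whereas np.maximum and np.minimum must propagate the NaN. Comparing x1 against itself with _CMP_NEQ_UQ flags exactly the NaN lanes, and the blend re-inserts x1 there. A standalone demonstration (assumes AVX512F hardware, gcc -mavx512f):

#include <immintrin.h>
#include <math.h>
#include <stdio.h>

int main(void)
{
    __m512d x1 = _mm512_set1_pd(NAN);
    __m512d x2 = _mm512_set1_pd(1.0);

    __m512d raw = _mm512_max_pd(x1, x2);    /* 1.0: the NaN is lost */
    __mmask8 nan_mask = _mm512_cmp_pd_mask(x1, x1, _CMP_NEQ_UQ);
    /* re-insert x1 wherever it was NaN */
    __m512d fixed = _mm512_mask_blend_pd(nan_mask, raw, x1);

    double a[8], b[8];
    _mm512_storeu_pd(a, raw);
    _mm512_storeu_pd(b, fixed);
    printf("raw=%g fixed=%g\n", a[0], b[0]);    /* raw=1 fixed=nan */
    return 0;
}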
diff --git a/numpy/core/tests/test_conversion_utils.py b/numpy/core/tests/test_conversion_utils.py
index 3c3f9e6e1..e96113d09 100644
--- a/numpy/core/tests/test_conversion_utils.py
+++ b/numpy/core/tests/test_conversion_utils.py
@@ -154,3 +154,34 @@ class TestCastingConverter(StringConverterTestCase):
         self._check("safe", "NPY_SAFE_CASTING")
         self._check("same_kind", "NPY_SAME_KIND_CASTING")
         self._check("unsafe", "NPY_UNSAFE_CASTING")
+
+
+class TestIntpConverter:
+    """ Tests of PyArray_IntpConverter """
+    conv = mt.run_intp_converter
+
+    def test_basic(self):
+        assert self.conv(1) == (1,)
+        assert self.conv((1, 2)) == (1, 2)
+        assert self.conv([1, 2]) == (1, 2)
+        assert self.conv(()) == ()
+
+    def test_none(self):
+        # once the warning expires, this will raise TypeError
+        with pytest.warns(DeprecationWarning):
+            assert self.conv(None) == ()
+
+    def test_float(self):
+        with pytest.raises(TypeError):
+            self.conv(1.0)
+        with pytest.raises(TypeError):
+            self.conv([1, 1.0])
+
+    def test_too_large(self):
+        with pytest.raises(ValueError):
+            self.conv(2**64)
+
+    def test_too_many_dims(self):
+        assert self.conv([1]*32) == (1,)*32
+        with pytest.raises(ValueError):
+            self.conv([1]*33)
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index e7965c0ca..91acd6ac3 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -3,6 +3,7 @@ import warnings
 import fnmatch
 import itertools
 import pytest
+import sys
 from fractions import Fraction
 
 import numpy.core.umath as ncu
@@ -771,6 +772,51 @@ class TestSpecialFloats:
             for dt in ['f', 'd', 'g']:
                 assert_raises(FloatingPointError, np.reciprocal, np.array(-0.0, dtype=dt))
 
+class TestFPClass:
+    @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
+    def test_fpclass(self, stride):
+        arr_f64 = np.array([np.nan, -np.nan, np.inf, -np.inf, -1.0, 1.0, -0.0, 0.0, 2.2251e-308, -2.2251e-308], dtype='d')
+        arr_f32 = np.array([np.nan, -np.nan, np.inf, -np.inf, -1.0, 1.0, -0.0, 0.0, 1.4013e-045, -1.4013e-045], dtype='f')
+        nan     = np.array([True, True, False, False, False, False, False, False, False, False])
+        inf     = np.array([False, False, True, True, False, False, False, False, False, False])
+        sign    = np.array([False, True, False, True, True, False, True, False, False, True])
+        finite  = np.array([False, False, False, False, True, True, True, True, True, True])
+        assert_equal(np.isnan(arr_f32[::stride]), nan[::stride])
+        assert_equal(np.isnan(arr_f64[::stride]), nan[::stride])
+        assert_equal(np.isinf(arr_f32[::stride]), inf[::stride])
+        assert_equal(np.isinf(arr_f64[::stride]), inf[::stride])
+        assert_equal(np.signbit(arr_f32[::stride]), sign[::stride])
+        assert_equal(np.signbit(arr_f64[::stride]), sign[::stride])
+        assert_equal(np.isfinite(arr_f32[::stride]), finite[::stride])
+        assert_equal(np.isfinite(arr_f64[::stride]), finite[::stride])
+
+class TestLDExp:
+    @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
+    @pytest.mark.parametrize("dtype", ['f', 'd'])
+    def test_ldexp(self, dtype, stride):
+        mant = np.array([0.125, 0.25, 0.5, 1., 1., 2., 4., 8.], dtype=dtype)
+        exp  = np.array([3, 2, 1, 0, 0, -1, -2, -3], dtype='i')
+        out  = np.zeros(8, dtype=dtype)
+        assert_equal(np.ldexp(mant[::stride], exp[::stride], out=out[::stride]), np.ones(8, dtype=dtype)[::stride])
+        assert_equal(out[::stride], np.ones(8, dtype=dtype)[::stride])
+
+class TestFRExp:
+    @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
+    @pytest.mark.parametrize("dtype", ['f', 'd'])
+    @pytest.mark.skipif(not sys.platform.startswith('linux'),
+                        reason="np.frexp gives different answers for NAN/INF on windows and linux")
+    def test_frexp(self, dtype, stride):
+        arr = np.array([np.nan, np.nan, np.inf, -np.inf, 0.0, -0.0, 1.0, -1.0], dtype=dtype)
+        mant_true = np.array([np.nan, np.nan, np.inf, -np.inf, 0.0, -0.0, 0.5, -0.5], dtype=dtype)
+        exp_true  = np.array([0, 0, 0, 0, 0, 0, 1, 1], dtype='i')
+        out_mant  = np.ones(8, dtype=dtype)
+        out_exp   = 2*np.ones(8, dtype='i')
+        mant, exp = np.frexp(arr[::stride], out=(out_mant[::stride], out_exp[::stride]))
+        assert_equal(mant_true[::stride], mant)
+        assert_equal(exp_true[::stride], exp)
+        assert_equal(out_mant[::stride], mant_true[::stride])
+        assert_equal(out_exp[::stride], exp_true[::stride])
+
 # func : [maxulperror, low, high]
 avx_ufuncs = {'sqrt'        :[1,  0.,   100.],
               'absolute'    :[0, -100., 100.],
diff --git a/numpy/f2py/rules.py b/numpy/f2py/rules.py
index 6750bf705..ecfc71ae3 100755
--- a/numpy/f2py/rules.py
+++ b/numpy/f2py/rules.py
@@ -194,7 +194,7 @@ PyMODINIT_FUNC PyInit_#modulename#(void) {
 \tint i;
 \tPyObject *m,*d, *s, *tmp;
 \tm = #modulename#_module = PyModule_Create(&moduledef);
-\tPy_TYPE(&PyFortran_Type) = &PyType_Type;
+\tPy_SET_TYPE(&PyFortran_Type, &PyType_Type);
 \timport_array();
 \tif (PyErr_Occurred())
 \t\t{PyErr_SetString(PyExc_ImportError, \"can't initialize module #modulename# (failed to import numpy)\"); return m;}
diff --git a/numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c b/numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c
index 83c0da2cf..0db33e714 100644
--- a/numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c
+++ b/numpy/f2py/tests/src/array_from_pyobj/wrapmodule.c
@@ -144,7 +144,7 @@ static struct PyModuleDef moduledef = {
 PyMODINIT_FUNC PyInit_test_array_from_pyobj_ext(void) {
   PyObject *m,*d, *s;
   m = wrap_module = PyModule_Create(&moduledef);
-  Py_TYPE(&PyFortran_Type) = &PyType_Type;
+  Py_SET_TYPE(&PyFortran_Type, &PyType_Type);
   import_array();
   if (PyErr_Occurred())
     Py_FatalError("can't initialize module wrap (failed to import numpy)");
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 48b0a0830..7a23aeab7 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1334,7 +1334,7 @@ def interp(x, xp, fp, left=None, right=None, period=None):
 
     See Also
     --------
-    scipy.interpolate 
+    scipy.interpolate
 
     Notes
     -----
@@ -3273,10 +3273,17 @@ def _sinc_dispatcher(x):
 
 @array_function_dispatch(_sinc_dispatcher)
 def sinc(x):
-    """
-    Return the sinc function.
+    r"""
+    Return the normalized sinc function.
+
+    The sinc function is :math:`\sin(\pi x)/(\pi x)`.
+
+    .. note::
 
-    The sinc function is :math:`\\sin(\\pi x)/(\\pi x)`.
+        Note the normalization factor of ``pi`` used in the definition.
+        This is the most commonly used definition in signal processing.
+        Use ``sinc(x / np.pi)`` to obtain the unnormalized sinc function
+        :math:`\sin(x)/(x)` that is more common in mathematics.
 
     Parameters
     ----------
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 111c2790c..2e54dce5f 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -646,10 +646,14 @@ cdef class Generator:
             if abs(p_sum - 1.) > atol:
                 raise ValueError("probabilities do not sum to 1")
 
-        shape = size
-        if shape is not None:
+        # `shape == None` means `shape == ()`, but with scalar unpacking at the
+        # end
+        is_scalar = size is None
+        if not is_scalar:
+            shape = size
             size = np.prod(shape, dtype=np.intp)
         else:
+            shape = ()
             size = 1
 
         # Actual sampling
@@ -733,10 +737,9 @@ cdef class Generator:
                     idx_data[j - pop_size_i + size_i] = j
                 if shuffle:
                     self._shuffle_int(size_i, 1, idx_data)
-        if shape is not None:
-            idx.shape = shape
+        idx.shape = shape
 
-        if shape is None and isinstance(idx, np.ndarray):
+        if is_scalar and isinstance(idx, np.ndarray):
             # In most cases a scalar will have been made an array
             idx = idx.item(0)
 
@@ -744,7 +747,7 @@ cdef class Generator:
         if a.ndim == 0:
             return idx
 
-        if shape is not None and idx.ndim == 0:
+        if not is_scalar and idx.ndim == 0:
             # If size == () then the user requested a 0-d array as opposed to
             # a scalar object when size is None. However a[idx] is always a
             # scalar and not an array. So this makes sure the result is an
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index f2805871d..8820a6e09 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -930,10 +930,14 @@ cdef class RandomState:
            if abs(p_sum - 1.) > atol:
                 raise ValueError("probabilities do not sum to 1")
 
-        shape = size
-        if shape is not None:
+        # `shape == None` means `shape == ()`, but with scalar unpacking at the
+        # end
+        is_scalar = size is None
+        if not is_scalar:
+            shape = size
             size = np.prod(shape, dtype=np.intp)
         else:
+            shape = ()
             size = 1
 
         # Actual sampling
@@ -977,10 +981,9 @@ cdef class RandomState:
             idx = found
         else:
             idx = self.permutation(pop_size)[:size]
-            if shape is not None:
-                idx.shape = shape
+            idx.shape = shape
 
-        if shape is None and isinstance(idx, np.ndarray):
+        if is_scalar and isinstance(idx, np.ndarray):
             # In most cases a scalar will have been made an array
             idx = idx.item(0)
 
@@ -988,7 +991,7 @@ cdef class RandomState:
         if a.ndim == 0:
             return idx
 
-        if shape is not None and idx.ndim == 0:
+        if not is_scalar and idx.ndim == 0:
             # If size == () then the user requested a 0-d array as opposed to
             # a scalar object when size is None. However a[idx] is always a
             # scalar and not an array. So this makes sure the result is an
