diff options
29 files changed, 460 insertions, 185 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index f6e8c09ef..a7c0e2852 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -61,6 +61,9 @@ C API The changes to *swapaxes* also apply to the *PyArray_SwapAxes* C function, which now returns a view in all cases. +The dtype structure (PyArray_Descr) has a new member at the end to cache +its hash value. This shouldn't affect any well-written applications. + recarray field return types ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Previously the returned types for recarray fields accessed by attribute and by @@ -153,6 +156,12 @@ interpolation behavior. NumPy arrays are supported as input for ``pad_width``, and an exception is raised if its values are not of integral type. +*np.argmax* and *np.argmin* now support an ``out`` argument +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ``out`` parameter was added to *np.argmax* and *np.argmin* for consistency +with *ndarray.argmax* and *ndarray.argmin*. The new parameter behaves exactly +as it does in those methods. + More system C99 complex functions detected and used ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ All of the functions ``in complex.h`` are now detected. There are new diff --git a/numpy/__init__.py b/numpy/__init__.py index 39933e8ca..d4ef54d83 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -224,3 +224,4 @@ else: import warnings warnings.filterwarnings("ignore", message="numpy.dtype size changed") warnings.filterwarnings("ignore", message="numpy.ufunc size changed") + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 7dd8c5649..4cc626ca9 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -2982,7 +2982,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__setstate__', add_newdoc('numpy.core.multiarray', 'ndarray', ('all', """ - a.all(axis=None, out=None) + a.all(axis=None, out=None, keepdims=False) Returns True if all elements evaluate to True. @@ -2997,7 +2997,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('all', add_newdoc('numpy.core.multiarray', 'ndarray', ('any', """ - a.any(axis=None, out=None) + a.any(axis=None, out=None, keepdims=False) Returns True if any of the elements of `a` evaluate to True. @@ -3198,9 +3198,10 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('choose', add_newdoc('numpy.core.multiarray', 'ndarray', ('clip', """ - a.clip(a_min, a_max, out=None) + a.clip(min=None, max=None, out=None) - Return an array whose values are limited to ``[a_min, a_max]``. + Return an array whose values are limited to ``[min, max]``. + One of max or min must be given. Refer to `numpy.clip` for full documentation. @@ -3656,7 +3657,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('max', add_newdoc('numpy.core.multiarray', 'ndarray', ('mean', """ - a.mean(axis=None, dtype=None, out=None) + a.mean(axis=None, dtype=None, out=None, keepdims=False) Returns the average of the array elements along given axis. @@ -3671,7 +3672,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('mean', add_newdoc('numpy.core.multiarray', 'ndarray', ('min', """ - a.min(axis=None, out=None) + a.min(axis=None, out=None, keepdims=False) Return the minimum along a given axis. @@ -3769,7 +3770,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('nonzero', add_newdoc('numpy.core.multiarray', 'ndarray', ('prod', """ - a.prod(axis=None, dtype=None, out=None) + a.prod(axis=None, dtype=None, out=None, keepdims=False) Return the product of the array elements over the given axis @@ -4300,7 +4301,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('squeeze', add_newdoc('numpy.core.multiarray', 'ndarray', ('std', """ - a.std(axis=None, dtype=None, out=None, ddof=0) + a.std(axis=None, dtype=None, out=None, ddof=0, keepdims=False) Returns the standard deviation of the array elements along given axis. @@ -4315,7 +4316,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('std', add_newdoc('numpy.core.multiarray', 'ndarray', ('sum', """ - a.sum(axis=None, dtype=None, out=None) + a.sum(axis=None, dtype=None, out=None, keepdims=False) Return the sum of the array elements over the given axis. @@ -4547,7 +4548,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('transpose', add_newdoc('numpy.core.multiarray', 'ndarray', ('var', """ - a.var(axis=None, dtype=None, out=None, ddof=0) + a.var(axis=None, dtype=None, out=None, ddof=0, keepdims=False) Returns the variance of the array elements, along given axis. diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index b0c141178..549647df2 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -900,7 +900,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None): return argsort(axis, kind, order) -def argmax(a, axis=None): +def argmax(a, axis=None, out=None): """ Returns the indices of the maximum values along an axis. @@ -911,6 +911,9 @@ def argmax(a, axis=None): axis : int, optional By default, the index is into the flattened array, otherwise along the specified axis. + out : array, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. Returns ------- @@ -953,11 +956,11 @@ def argmax(a, axis=None): try: argmax = a.argmax except AttributeError: - return _wrapit(a, 'argmax', axis) - return argmax(axis) + return _wrapit(a, 'argmax', axis, out) + return argmax(axis, out) -def argmin(a, axis=None): +def argmin(a, axis=None, out=None): """ Returns the indices of the minimum values along an axis. @@ -968,6 +971,9 @@ def argmin(a, axis=None): axis : int, optional By default, the index is into the flattened array, otherwise along the specified axis. + out : array, optional + If provided, the result will be inserted into this array. It should + be of the appropriate shape and dtype. Returns ------- @@ -1010,8 +1016,8 @@ def argmin(a, axis=None): try: argmin = a.argmin except AttributeError: - return _wrapit(a, 'argmin', axis) - return argmin(axis) + return _wrapit(a, 'argmin', axis, out) + return argmin(axis, out) def searchsorted(a, v, side='left', sorter=None): diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index 78f79d5fe..edae27c72 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -619,6 +619,10 @@ typedef struct _PyArray_Descr { * for NumPy 1.7.0. */ NpyAuxData *c_metadata; + /* Cached hash value (-1 if not yet computed). + * This was added for NumPy 2.0.0. + */ + npy_hash_t hash; } PyArray_Descr; typedef struct _arr_descr { diff --git a/numpy/core/include/numpy/npy_3kcompat.h b/numpy/core/include/numpy/npy_3kcompat.h index 8a9109c5c..ef5b5694c 100644 --- a/numpy/core/include/numpy/npy_3kcompat.h +++ b/numpy/core/include/numpy/npy_3kcompat.h @@ -486,19 +486,6 @@ NpyCapsule_Check(PyObject *ptr) #endif -/* - * Hash value compatibility. - * As of Python 3.2 hash values are of type Py_hash_t. - * Previous versions use C long. - */ -#if PY_VERSION_HEX < 0x03020000 -typedef long npy_hash_t; -#define NPY_SIZEOF_HASH_T NPY_SIZEOF_LONG -#else -typedef Py_hash_t npy_hash_t; -#define NPY_SIZEOF_HASH_T NPY_SIZEOF_INTP -#endif - #ifdef __cplusplus } #endif diff --git a/numpy/core/include/numpy/npy_common.h b/numpy/core/include/numpy/npy_common.h index 92b03d20c..eff5dd339 100644 --- a/numpy/core/include/numpy/npy_common.h +++ b/numpy/core/include/numpy/npy_common.h @@ -317,6 +317,19 @@ typedef float npy_float; typedef double npy_double; /* + * Hash value compatibility. + * As of Python 3.2 hash values are of type Py_hash_t. + * Previous versions use C long. + */ +#if PY_VERSION_HEX < 0x03020000 +typedef long npy_hash_t; +#define NPY_SIZEOF_HASH_T NPY_SIZEOF_LONG +#else +typedef Py_hash_t npy_hash_t; +#define NPY_SIZEOF_HASH_T NPY_SIZEOF_INTP +#endif + +/* * Disabling C99 complex usage: a lot of C code in numpy/scipy rely on being * able to do .real/.imag. Will have to convert code first. */ diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 7f0649158..11b443cf8 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -758,6 +758,7 @@ def configuration(parent_package='',top_path=None): join('src', 'multiarray', 'ucsnarrow.h'), join('src', 'multiarray', 'usertypes.h'), join('src', 'multiarray', 'vdot.h'), + join('src', 'private', 'npy_config.h'), join('src', 'private', 'templ_common.h.src'), join('src', 'private', 'lowlevel_strided_loops.h'), join('include', 'numpy', 'arrayobject.h'), diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 05b843fc4..8287c2268 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -4031,6 +4031,8 @@ static PyArray_Descr @from@_Descr = { NULL, /* c_metadata */ NULL, + /* hash */ + -1, }; /**end repeat**/ @@ -4172,6 +4174,8 @@ NPY_NO_EXPORT PyArray_Descr @from@_Descr = { NULL, /* c_metadata */ NULL, + /* hash */ + -1, }; /**end repeat**/ diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c index edcca9857..d4a08a4ee 100644 --- a/numpy/core/src/multiarray/calculation.c +++ b/numpy/core/src/multiarray/calculation.c @@ -618,7 +618,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) } /* arr.real = a.real.round(decimals) */ - part = PyObject_GetAttrString(arr, "real"); + part = PyObject_GetAttrString(a, "real"); if (part == NULL) { Py_DECREF(arr); return NULL; @@ -639,7 +639,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) } /* arr.imag = a.imag.round(decimals) */ - part = PyObject_GetAttrString(arr, "imag"); + part = PyObject_GetAttrString(a, "imag"); if (part == NULL) { Py_DECREF(arr); return NULL; diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 816778b91..a5f3b3d55 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -684,7 +684,16 @@ _IsAligned(PyArrayObject *ap) /* alignment 1 types should have a efficient alignment for copy loops */ if (PyArray_ISFLEXIBLE(ap) || PyArray_ISSTRING(ap)) { - alignment = NPY_MAX_COPY_ALIGNMENT; + npy_intp itemsize = PyArray_ITEMSIZE(ap); + /* power of two sizes may be loaded in larger moves */ + if (((itemsize & (itemsize - 1)) == 0)) { + alignment = itemsize > NPY_MAX_COPY_ALIGNMENT ? + NPY_MAX_COPY_ALIGNMENT : itemsize; + } + else { + /* if not power of two it will be accessed bytewise */ + alignment = 1; + } } if (alignment == 1) { diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 0993190b7..bbcd5da36 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -1591,6 +1591,7 @@ PyArray_DescrNew(PyArray_Descr *base) } Py_XINCREF(newdescr->typeobj); Py_XINCREF(newdescr->metadata); + newdescr->hash = -1; return newdescr; } @@ -1994,6 +1995,8 @@ arraydescr_names_set(PyArray_Descr *self, PyObject *val) return -1; } } + /* Invalidate cached hash value */ + self->hash = -1; /* Update dictionary keys in fields */ new_names = PySequence_Tuple(val); new_fields = PyDict_New(); @@ -2443,6 +2446,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) version); return NULL; } + /* Invalidate cached hash value */ + self->hash = -1; if (version == 1 || version == 0) { if (fields != Py_None) { diff --git a/numpy/core/src/multiarray/hashdescr.c b/numpy/core/src/multiarray/hashdescr.c index 29d69fddb..3981ccc0e 100644 --- a/numpy/core/src/multiarray/hashdescr.c +++ b/numpy/core/src/multiarray/hashdescr.c @@ -28,7 +28,7 @@ static int _is_array_descr_builtin(PyArray_Descr* descr); static int _array_descr_walk(PyArray_Descr* descr, PyObject *l); -static int _array_descr_walk_fields(PyObject* fields, PyObject* l); +static int _array_descr_walk_fields(PyObject *names, PyObject* fields, PyObject* l); static int _array_descr_builtin(PyArray_Descr* descr, PyObject *l); /* @@ -86,7 +86,6 @@ static int _array_descr_builtin(PyArray_Descr* descr, PyObject *l) "(Hash) Error while computing builting hash"); goto clean_t; } - Py_INCREF(item); PyList_Append(l, item); } @@ -104,18 +103,35 @@ clean_t: * * Return 0 on success */ -static int _array_descr_walk_fields(PyObject* fields, PyObject* l) +static int _array_descr_walk_fields(PyObject *names, PyObject* fields, PyObject* l) { - PyObject *key, *value, *foffset, *fdescr; + PyObject *key, *value, *foffset, *fdescr, *ftitle; Py_ssize_t pos = 0; int st; - while (PyDict_Next(fields, &pos, &key, &value)) { + if (!PyTuple_Check(names)) { + PyErr_SetString(PyExc_SystemError, + "(Hash) names is not a tuple ???"); + return -1; + } + if (!PyDict_Check(fields)) { + PyErr_SetString(PyExc_SystemError, + "(Hash) fields is not a dict ???"); + return -1; + } + + for (pos = 0; pos < PyTuple_GET_SIZE(names); pos++) { /* * For each field, add the key + descr + offset to l */ - + key = PyTuple_GET_ITEM(names, pos); + value = PyDict_GetItem(fields, key); /* XXX: are those checks necessary ? */ + if (value == NULL) { + PyErr_SetString(PyExc_SystemError, + "(Hash) names and fields inconsistent ???"); + return -1; + } if (!PyUString_Check(key)) { PyErr_SetString(PyExc_SystemError, "(Hash) key of dtype dict not a string ???"); @@ -126,15 +142,14 @@ static int _array_descr_walk_fields(PyObject* fields, PyObject* l) "(Hash) value of dtype dict not a dtype ???"); return -1; } - if (PyTuple_Size(value) < 2) { + if (PyTuple_GET_SIZE(value) < 2) { PyErr_SetString(PyExc_SystemError, "(Hash) Less than 2 items in dtype dict ???"); return -1; } - Py_INCREF(key); PyList_Append(l, key); - fdescr = PyTuple_GetItem(value, 0); + fdescr = PyTuple_GET_ITEM(value, 0); if (!PyArray_DescrCheck(fdescr)) { PyErr_SetString(PyExc_SystemError, "(Hash) First item in compound dtype tuple not a descr ???"); @@ -149,16 +164,20 @@ static int _array_descr_walk_fields(PyObject* fields, PyObject* l) } } - foffset = PyTuple_GetItem(value, 1); + foffset = PyTuple_GET_ITEM(value, 1); if (!PyInt_Check(foffset)) { PyErr_SetString(PyExc_SystemError, "(Hash) Second item in compound dtype tuple not an int ???"); return -1; } else { - Py_INCREF(foffset); PyList_Append(l, foffset); } + + if (PyTuple_GET_SIZE(value) > 2) { + ftitle = PyTuple_GET_ITEM(value, 2); + PyList_Append(l, ftitle); + } } return 0; @@ -186,12 +205,10 @@ static int _array_descr_walk_subarray(PyArray_ArrayDescr* adescr, PyObject *l) "(Hash) Error while getting shape item of subarray dtype ???"); return -1; } - Py_INCREF(item); PyList_Append(l, item); } } else if (PyInt_Check(adescr->shape)) { - Py_INCREF(adescr->shape); PyList_Append(l, adescr->shape); } else { @@ -219,12 +236,7 @@ static int _array_descr_walk(PyArray_Descr* descr, PyObject *l) } else { if(descr->fields != NULL && descr->fields != Py_None) { - if (!PyDict_Check(descr->fields)) { - PyErr_SetString(PyExc_SystemError, - "(Hash) fields is not a dict ???"); - return -1; - } - st = _array_descr_walk_fields(descr->fields, l); + st = _array_descr_walk_fields(descr->names, descr->fields, l); if (st) { return -1; } @@ -256,44 +268,31 @@ static int _PyArray_DescrHashImp(PyArray_Descr *descr, npy_hash_t *hash) st = _array_descr_walk(descr, l); if (st) { - goto clean_l; + Py_DECREF(l); + return -1; } /* * Convert the list to tuple and compute the tuple hash using python * builtin function */ - tl = PyTuple_New(PyList_Size(l)); - for(i = 0; i < PyList_Size(l); ++i) { - item = PyList_GetItem(l, i); - if (item == NULL) { - PyErr_SetString(PyExc_SystemError, - "(Hash) Error while translating the list into a tuple " \ - "(NULL item)"); - goto clean_tl; - } - PyTuple_SetItem(tl, i, item); - } + tl = PyList_AsTuple(l); + Py_DECREF(l); + if (tl == NULL) + return -1; *hash = PyObject_Hash(tl); + Py_DECREF(tl); if (*hash == -1) { /* XXX: does PyObject_Hash set an exception on failure ? */ #if 0 PyErr_SetString(PyExc_SystemError, "(Hash) Error while hashing final tuple"); #endif - goto clean_tl; + return -1; } - Py_DECREF(tl); - Py_DECREF(l); return 0; - -clean_tl: - Py_DECREF(tl); -clean_l: - Py_DECREF(l); - return -1; } NPY_NO_EXPORT npy_hash_t @@ -301,7 +300,6 @@ PyArray_DescrHash(PyObject* odescr) { PyArray_Descr *descr; int st; - npy_hash_t hash; if (!PyArray_DescrCheck(odescr)) { PyErr_SetString(PyExc_ValueError, @@ -310,10 +308,12 @@ PyArray_DescrHash(PyObject* odescr) } descr = (PyArray_Descr*)odescr; - st = _PyArray_DescrHashImp(descr, &hash); - if (st) { - return -1; + if (descr->hash == -1) { + st = _PyArray_DescrHashImp(descr, &descr->hash); + if (st) { + return -1; + } } - return hash; + return descr->hash; } diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index a29c47555..c59a125e0 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject * PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) { npy_intp *counts; - npy_intp n, n_outer, i, j, k, chunk, total; - npy_intp tmp; - int nd; + npy_intp n, n_outer, i, j, k, chunk; + npy_intp total = 0; + npy_bool broadcast = NPY_FALSE; PyArrayObject *repeats = NULL; PyObject *ap = NULL; PyArrayObject *ret = NULL; @@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) if (repeats == NULL) { return NULL; } - nd = PyArray_NDIM(repeats); + + /* + * Scalar and size 1 'repeat' arrays broadcast to any shape, for all + * other inputs the dimension must match exactly. + */ + if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) { + broadcast = NPY_TRUE; + } + counts = (npy_intp *)PyArray_DATA(repeats); - if ((ap=PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY))==NULL) { + if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) { Py_DECREF(repeats); return NULL; } aop = (PyArrayObject *)ap; - if (nd == 1) { - n = PyArray_DIMS(repeats)[0]; - } - else { - /* nd == 0 */ - n = PyArray_DIMS(aop)[axis]; - } - if (PyArray_DIMS(aop)[axis] != n) { - PyErr_SetString(PyExc_ValueError, - "a.shape[axis] != len(repeats)"); + n = PyArray_DIM(aop, axis); + + if (!broadcast && PyArray_SIZE(repeats) != n) { + PyErr_Format(PyExc_ValueError, + "operands could not be broadcast together " + "with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0)); goto fail; } - - if (nd == 0) { - total = counts[0]*n; + if (broadcast) { + total = counts[0] * n; } else { - - total = 0; for (j = 0; j < n; j++) { if (counts[j] < 0) { PyErr_SetString(PyExc_ValueError, "count < 0"); @@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) } } - /* Construct new array */ PyArray_DIMS(aop)[axis] = total; Py_INCREF(PyArray_DESCR(aop)); @@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) } for (i = 0; i < n_outer; i++) { for (j = 0; j < n; j++) { - tmp = nd ? counts[j] : counts[0]; + npy_intp tmp = broadcast ? counts[0] : counts[j]; for (k = 0; k < tmp; k++) { memcpy(new_data, old_data, chunk); new_data += chunk; diff --git a/numpy/core/src/private/npy_config.h b/numpy/core/src/private/npy_config.h index 6e98dc7e9..580b00706 100644 --- a/numpy/core/src/private/npy_config.h +++ b/numpy/core/src/private/npy_config.h @@ -3,6 +3,7 @@ #include "config.h" #include "numpy/numpyconfig.h" +#include "numpy/npy_cpu.h" /* * largest alignment the copy loops might require @@ -13,7 +14,11 @@ * amd64 is not harmed much by the bloat as the system provides 16 byte * alignment by default. */ +#if (defined NPY_CPU_X86 || defined _WIN32) +#define NPY_MAX_COPY_ALIGNMENT 8 +#else #define NPY_MAX_COPY_ALIGNMENT 16 +#endif /* blacklist */ diff --git a/numpy/core/src/umath/scalarmath.c.src b/numpy/core/src/umath/scalarmath.c.src index e2c8137b3..e4fc617a5 100644 --- a/numpy/core/src/umath/scalarmath.c.src +++ b/numpy/core/src/umath/scalarmath.c.src @@ -485,10 +485,8 @@ static void /**begin repeat - * #name = byte, short, int, long, longlong, - * float, double, longdouble# - * #type = npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_float, npy_double, npy_longdouble# + * #name = byte, short, int, long, longlong# + * #type = npy_byte, npy_short, npy_int, npy_long, npy_longlong# */ static void @name@_ctype_absolute(@type@ a, @type@ *out) @@ -497,6 +495,18 @@ static void } /**end repeat**/ +/**begin repeat + * #name = float, double, longdouble# + * #type = npy_float, npy_double, npy_longdouble# + * #c = f,,l# + */ +static void +@name@_ctype_absolute(@type@ a, @type@ *out) +{ + *out = npy_fabs@c@(a); +} +/**end repeat**/ + static void half_ctype_absolute(npy_half a, npy_half *out) { diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 36201b3ea..9f89d71c2 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -5547,7 +5547,7 @@ ufunc_get_identity(PyUFuncObject *ufunc) case PyUFunc_Zero: return PyInt_FromLong(0); } - return Py_None; + Py_RETURN_NONE; } static PyObject * diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 3a255b038..852660432 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -125,6 +125,21 @@ class TestRecord(TestCase): 'titles': ['RRed pixel', 'Blue pixel']}) assert_dtype_not_equal(a, b) + def test_mutate(self): + # Mutating a dtype should reset the cached hash value + a = np.dtype([('yo', np.int)]) + b = np.dtype([('yo', np.int)]) + c = np.dtype([('ye', np.int)]) + assert_dtype_equal(a, b) + assert_dtype_not_equal(a, c) + a.names = ['ye'] + assert_dtype_equal(a, c) + assert_dtype_not_equal(a, b) + state = b.__reduce__()[2] + a.__setstate__(state) + assert_dtype_equal(a, b) + assert_dtype_not_equal(a, c) + def test_not_lists(self): """Test if an appropriate exception is raised when passing bad values to the dtype constructor. diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 937ca9d72..314adf4d1 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -69,6 +69,17 @@ class TestFlags(TestCase): assert_equal(self.a.flags.aligned, True) assert_equal(self.a.flags.updateifcopy, False) + def test_string_align(self): + a = np.zeros(4, dtype=np.dtype('|S4')) + assert_(a.flags.aligned) + # not power of two are accessed bytewise and thus considered aligned + a = np.zeros(5, dtype=np.dtype('|S4')) + assert_(a.flags.aligned) + + def test_void_align(self): + a = np.zeros(4, dtype=np.dtype([("a", "i4"), ("b", "i4")])) + assert_(a.flags.aligned) + class TestHash(TestCase): # see #3793 def test_int(self): @@ -857,11 +868,22 @@ class TestBool(TestCase): self.assertEqual(np.count_nonzero(a), builtins.sum(a.tolist())) class TestMethods(TestCase): - def test_test_round(self): - assert_equal(array([1.2, 1.5]).round(), [1, 2]) - assert_equal(array(1.5).round(), 2) - assert_equal(array([12.2, 15.5]).round(-1), [10, 20]) - assert_equal(array([12.15, 15.51]).round(1), [12.2, 15.5]) + def test_round(self): + def check_round(arr, expected, *round_args): + assert_equal(arr.round(*round_args), expected) + # With output array + out = np.zeros_like(arr) + res = arr.round(*round_args, out=out) + assert_equal(out, expected) + assert_equal(out, res) + + check_round(array([1.2, 1.5]), [1, 2]) + check_round(array(1.5), 2) + check_round(array([12.2, 15.5]), [10, 20], -1) + check_round(array([12.15, 15.51]), [12.2, 15.5], 1) + # Complex rounding + check_round(array([4.5 + 1.5j]), [4 + 2j]) + check_round(array([12.5 + 15.5j]), [10 + 20j], -1) def test_transpose(self): a = array([[1, 2], [3, 4]]) @@ -2638,6 +2660,22 @@ class TestArgmax(TestCase): d[5942] = "as" assert_equal(d.argmax(), 5942) + def test_np_vs_ndarray(self): + # make sure both ndarray.argmax and numpy.argmax support out/axis args + a = np.random.normal(size=(2,3)) + + #check positional args + out1 = zeros(2, dtype=int) + out2 = zeros(2, dtype=int) + assert_equal(a.argmax(1, out1), np.argmax(a, 1, out2)) + assert_equal(out1, out2) + + #check keyword args + out1 = zeros(3, dtype=int) + out2 = zeros(3, dtype=int) + assert_equal(a.argmax(out=out1, axis=0), np.argmax(a, out=out2, axis=0)) + assert_equal(out1, out2) + class TestArgmin(TestCase): @@ -2748,6 +2786,22 @@ class TestArgmin(TestCase): d[6001] = "0" assert_equal(d.argmin(), 6001) + def test_np_vs_ndarray(self): + # make sure both ndarray.argmin and numpy.argmin support out/axis args + a = np.random.normal(size=(2,3)) + + #check positional args + out1 = zeros(2, dtype=int) + out2 = ones(2, dtype=int) + assert_equal(a.argmin(1, out1), np.argmin(a, 1, out2)) + assert_equal(out1, out2) + + #check keyword args + out1 = zeros(3, dtype=int) + out2 = ones(3, dtype=int) + assert_equal(a.argmin(out=out1, axis=0), np.argmin(a, out=out2, axis=0)) + assert_equal(out1, out2) + class TestMinMax(TestCase): def test_scalar(self): diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 19c8d4457..fa2f52a23 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -10,6 +10,7 @@ import warnings import tempfile from os import path from io import BytesIO +from itertools import chain import numpy as np from numpy.testing import ( @@ -2118,6 +2119,12 @@ class TestRegression(TestCase): assert_raises(ValueError, np.frompyfunc, passer, 32, 1) + def test_repeat_broadcasting(self): + # gh-5743 + a = np.arange(60).reshape(3, 4, 5) + for axis in chain(range(-a.ndim, a.ndim), [None]): + assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis)) + if __name__ == "__main__": run_module_suite() diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index 3ba3beff9..8b6816958 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -11,6 +11,9 @@ types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, np.single, np.double, np.longdouble, np.csingle, np.cdouble, np.clongdouble] +floating_types = np.floating.__subclasses__() + + # This compares scalarmath against ufuncs. class TestTypes(TestCase): @@ -284,5 +287,26 @@ class TestSizeOf(TestCase): assert_raises(TypeError, d.__sizeof__, "a") +class TestAbs(TestCase): + + def _test_abs_func(self, absfunc): + for tp in floating_types: + x = tp(-1.5) + assert_equal(absfunc(x), 1.5) + x = tp(0.0) + res = absfunc(x) + # assert_equal() checks zero signedness + assert_equal(res, 0.0) + x = tp(-0.0) + res = absfunc(x) + assert_equal(res, 0.0) + + def test_builtin_abs(self): + self._test_abs_func(abs) + + def test_numpy_abs(self): + self._test_abs_func(np.abs) + + if __name__ == "__main__": run_module_suite() diff --git a/numpy/distutils/fcompiler/gnu.py b/numpy/distutils/fcompiler/gnu.py index b61574dba..f568135c0 100644 --- a/numpy/distutils/fcompiler/gnu.py +++ b/numpy/distutils/fcompiler/gnu.py @@ -40,22 +40,41 @@ class GnuFCompiler(FCompiler): while version_string.startswith('gfortran: warning'): version_string = version_string[version_string.find('\n')+1:] - # Try to find a valid version string - m = re.search(r'([0-9.]+)', version_string) - if m: - # g77 provides a longer version string that starts with GNU - # Fortran - if version_string.startswith('GNU Fortran'): - return ('g77', m.group(1)) - - # gfortran only outputs a version string such as #.#.#, so check - # if the match is at the start of the string - elif m.start() == 0: + # Gfortran versions from after 2010 will output a simple string + # (usually "x.y", "x.y.z" or "x.y.z-q") for ``-dumpversion``; older + # gfortrans may still return long version strings (``-dumpversion`` was + # an alias for ``--version``) + if len(version_string) <= 20: + # Try to find a valid version string + m = re.search(r'([0-9.]+)', version_string) + if m: + # g77 provides a longer version string that starts with GNU + # Fortran + if version_string.startswith('GNU Fortran'): + return ('g77', m.group(1)) + + # gfortran only outputs a version string such as #.#.#, so check + # if the match is at the start of the string + elif m.start() == 0: + return ('gfortran', m.group(1)) + else: + # Output probably from --version, try harder: + m = re.search(r'GNU Fortran\s+95.*?([0-9-.]+)', version_string) + if m: return ('gfortran', m.group(1)) + m = re.search(r'GNU Fortran.*?\-?([0-9-.]+)', version_string) + if m: + v = m.group(1) + if v.startswith('0') or v.startswith('2') or v.startswith('3'): + # the '0' is for early g77's + return ('g77', v) + else: + # at some point in the 4.x series, the ' 95' was dropped + # from the version string + return ('gfortran', v) - # If these checks fail, then raise an error to make the problem easy - # to find. - err = 'A valid Fortran verison was not found in this string:\n' + # If still nothing, raise an error to make the problem easy to find. + err = 'A valid Fortran version was not found in this string:\n' raise ValueError(err + version_string) def version_match(self, version_string): diff --git a/numpy/distutils/tests/test_fcompiler_gnu.py b/numpy/distutils/tests/test_fcompiler_gnu.py index f7a124c50..7ca99db22 100644 --- a/numpy/distutils/tests/test_fcompiler_gnu.py +++ b/numpy/distutils/tests/test_fcompiler_gnu.py @@ -1,6 +1,6 @@ from __future__ import division, absolute_import, print_function -from numpy.testing import * +from numpy.testing import TestCase, assert_, run_module_suite import numpy.distutils.fcompiler @@ -14,6 +14,12 @@ g77_version_strings = [ ] gfortran_version_strings = [ + ('GNU Fortran 95 (GCC 4.0.3 20051023 (prerelease) (Debian 4.0.2-3))', + '4.0.3'), + ('GNU Fortran 95 (GCC) 4.1.0', '4.1.0'), + ('GNU Fortran 95 (GCC) 4.2.0 20060218 (experimental)', '4.2.0'), + ('GNU Fortran (GCC) 4.3.0 20070316 (experimental)', '4.3.0'), + ('GNU Fortran (rubenvb-4.8.0) 4.8.0', '4.8.0'), ('4.8.0', '4.8.0'), ('4.0.3-7', '4.0.3'), ("gfortran: warning: couldn't understand kern.osversion '14.1.0\n4.9.1", diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 1a2133aa9..66a1b356c 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -314,21 +314,19 @@ def _write_array_header(fp, d, version=None): header = header + ' '*topad + '\n' header = asbytes(_filter_header(header)) - if len(header) >= (256*256) and version == (1, 0): - raise ValueError("header does not fit inside %s bytes required by the" - " 1.0 format" % (256*256)) - if len(header) < (256*256): - header_len_str = struct.pack('<H', len(header)) + hlen = len(header) + if hlen < 256*256 and version in (None, (1, 0)): version = (1, 0) - elif len(header) < (2**32): - header_len_str = struct.pack('<I', len(header)) + header_prefix = magic(1, 0) + struct.pack('<H', hlen) + elif hlen < 2**32 and version in (None, (2, 0)): version = (2, 0) + header_prefix = magic(2, 0) + struct.pack('<I', hlen) else: - raise ValueError("header does not fit inside 4 GiB required by " - "the 2.0 format") + msg = "Header length %s too big for version=%s" + msg %= (hlen, version) + raise ValueError(msg) - fp.write(magic(*version)) - fp.write(header_len_str) + fp.write(header_prefix) fp.write(header) return version @@ -389,7 +387,7 @@ def read_array_header_1_0(fp): If the data is invalid. """ - _read_array_header(fp, version=(1, 0)) + return _read_array_header(fp, version=(1, 0)) def read_array_header_2_0(fp): """ @@ -422,7 +420,7 @@ def read_array_header_2_0(fp): If the data is invalid. """ - _read_array_header(fp, version=(2, 0)) + return _read_array_header(fp, version=(2, 0)) def _filter_header(s): @@ -517,7 +515,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None, pickle_kwargs=None): +def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None): """ Write an array to an NPY file, including a header. @@ -535,6 +533,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + allow_pickle : bool, optional + Whether to allow writing pickled data. Default: True pickle_kwargs : dict, optional Additional keyword arguments to pass to pickle.dump, excluding 'protocol'. These are only useful when pickling objects in object @@ -543,7 +543,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): Raises ------ ValueError - If the array cannot be persisted. + If the array cannot be persisted. This includes the case of + allow_pickle=False and array being an object array. Various other errors If the array contains Python objects as part of its dtype, the process of pickling them may raise various errors if the objects @@ -565,6 +566,9 @@ def write_array(fp, array, version=None, pickle_kwargs=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. + if not allow_pickle: + raise ValueError("Object arrays cannot be saved when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} pickle.dump(array, fp, protocol=2, **pickle_kwargs) @@ -586,7 +590,7 @@ def write_array(fp, array, version=None, pickle_kwargs=None): fp.write(chunk.tobytes('C')) -def read_array(fp, pickle_kwargs=None): +def read_array(fp, allow_pickle=True, pickle_kwargs=None): """ Read an array from an NPY file. @@ -595,6 +599,8 @@ def read_array(fp, pickle_kwargs=None): fp : file_like object If this is not a real file object, then this may take extra memory and time. + allow_pickle : bool, optional + Whether to allow reading pickled data. Default: True pickle_kwargs : dict Additional keyword arguments to pass to pickle.load. These are only useful when loading object arrays saved on Python 2 when using @@ -608,7 +614,8 @@ def read_array(fp, pickle_kwargs=None): Raises ------ ValueError - If the data is invalid. + If the data is invalid, or allow_pickle=False and the file contains + an object array. """ version = read_magic(fp) @@ -622,6 +629,9 @@ def read_array(fp, pickle_kwargs=None): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. + if not allow_pickle: + raise ValueError("Object arrays cannot be loaded when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} try: diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 9aec98cc8..d22e8c047 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -906,9 +906,9 @@ def gradient(f, *varargs, **kwargs): Returns ------- - gradient : ndarray - N arrays of the same shape as `f` giving the derivative of `f` with - respect to each dimension. + gradient : list of ndarray + Each element of `list` has the same shape as `f` giving the derivative + of `f` with respect to each dimension. Examples -------- @@ -918,6 +918,10 @@ def gradient(f, *varargs, **kwargs): >>> np.gradient(x, 2) array([ 0.5 , 0.75, 1.25, 1.75, 2.25, 2.5 ]) + For two dimensional arrays, the return will be two arrays ordered by + axis. In this example the first array stands for the gradient in + rows and the second one in columns direction: + >>> np.gradient(np.array([[1, 2, 6], [3, 4, 5]], dtype=np.float)) [array([[ 2., 2., -1.], [ 2., 2., -1.]]), array([[ 1. , 2.5, 4. ], @@ -3735,6 +3739,7 @@ def insert(arr, obj, values, axis=None): [3, 5, 3]]) Difference between sequence and scalars: + >>> np.insert(a, [1], [[1],[2],[3]], axis=1) array([[1, 1, 1], [2, 2, 2], diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index b56d7d5a9..ec89397a0 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -164,6 +164,8 @@ class NpzFile(object): f : BagObj instance An object on which attribute can be performed as an alternative to getitem access on the `NpzFile` instance itself. + allow_pickle : bool, optional + Allow loading pickled data. Default: True pickle_kwargs : dict, optional Additional keyword arguments to pass on to pickle.load. These are only useful when loading object arrays saved on @@ -199,12 +201,14 @@ class NpzFile(object): """ - def __init__(self, fid, own_fid=False, pickle_kwargs=None): + def __init__(self, fid, own_fid=False, allow_pickle=True, + pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an # optional component of the so-called standard library. _zip = zipfile_factory(fid) self._files = _zip.namelist() self.files = [] + self.allow_pickle = allow_pickle self.pickle_kwargs = pickle_kwargs for x in self._files: if x.endswith('.npy'): @@ -262,6 +266,7 @@ class NpzFile(object): if magic == format.MAGIC_PREFIX: bytes = self.zip.open(key) return format.read_array(bytes, + allow_pickle=self.allow_pickle, pickle_kwargs=self.pickle_kwargs) else: return self.zip.read(key) @@ -295,7 +300,8 @@ class NpzFile(object): return self.files.__contains__(key) -def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): +def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, + encoding='ASCII'): """ Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files. @@ -312,6 +318,12 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): and sliced like any ndarray. Memory mapping is especially useful for accessing small fragments of large files without reading the entire file into memory. + allow_pickle : bool, optional + Allow loading pickled object arrays stored in npy files. Reasons for + disallowing pickles include security, as loading pickled data can + execute arbitrary code. If pickles are disallowed, loading object + arrays will fail. + Default: True fix_imports : bool, optional Only useful when loading Python 2 generated pickled files on Python 3, which includes npy/npz files containing object arrays. If `fix_imports` @@ -324,7 +336,6 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical data. Default: 'ASCII' - Returns ------- result : array, tuple, dict, etc. @@ -335,6 +346,8 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): ------ IOError If the input file does not exist or cannot be read. + ValueError + The file contains an object array, but allow_pickle=False given. See Also -------- @@ -430,15 +443,20 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): # Transfer file ownership to NpzFile tmp = own_fid own_fid = False - return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs) + return NpzFile(fid, own_fid=tmp, allow_pickle=allow_pickle, + pickle_kwargs=pickle_kwargs) elif magic == format.MAGIC_PREFIX: # .npy file if mmap_mode: return format.open_memmap(file, mode=mmap_mode) else: - return format.read_array(fid, pickle_kwargs=pickle_kwargs) + return format.read_array(fid, allow_pickle=allow_pickle, + pickle_kwargs=pickle_kwargs) else: # Try a pickle + if not allow_pickle: + raise ValueError("allow_pickle=False, but file does not contain " + "non-pickled data") try: return pickle.load(fid, **pickle_kwargs) except: @@ -449,7 +467,7 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): fid.close() -def save(file, arr, fix_imports=True): +def save(file, arr, allow_pickle=True, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -460,6 +478,14 @@ def save(file, arr, fix_imports=True): then the filename is unchanged. If file is a string, a ``.npy`` extension will be appended to the file name if it does not already have one. + allow_pickle : bool, optional + Allow saving object arrays using Python pickles. Reasons for disallowing + pickles include security (loading pickled data can execute arbitrary + code) and portability (pickled objects may not be loadable on different + Python installations, for example if the stored objects require libraries + that are not available, and not all pickled data is compatible between + Python 2 and Python 3). + Default: True fix_imports : bool, optional Only useful in forcing objects in object arrays on Python 3 to be pickled in a Python 2 compatible way. If `fix_imports` is True, pickle @@ -509,7 +535,8 @@ def save(file, arr, fix_imports=True): try: arr = np.asanyarray(arr) - format.write_array(fid, arr, pickle_kwargs=pickle_kwargs) + format.write_array(fid, arr, allow_pickle=allow_pickle, + pickle_kwargs=pickle_kwargs) finally: if own_fid: fid.close() @@ -621,7 +648,7 @@ def savez_compressed(file, *args, **kwds): _savez(file, args, kwds, True) -def _savez(file, args, kwds, compress, pickle_kwargs=None): +def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an optional # component of the so-called standard library. import zipfile @@ -656,6 +683,7 @@ def _savez(file, args, kwds, compress, pickle_kwargs=None): fid = open(tmpfile, 'wb') try: format.write_array(fid, np.asanyarray(val), + allow_pickle=allow_pickle, pickle_kwargs=pickle_kwargs) fid.close() fid = None diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index 169f01182..4f8a65148 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -599,6 +599,22 @@ def test_pickle_python2_python3(): encoding='latin1', fix_imports=False) +def test_pickle_disallow(): + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + path = os.path.join(data_dir, 'py2-objarr.npy') + assert_raises(ValueError, np.load, path, + allow_pickle=False, encoding='latin1') + + path = os.path.join(data_dir, 'py2-objarr.npz') + f = np.load(path, allow_pickle=False, encoding='latin1') + assert_raises(ValueError, f.__getitem__, 'x') + + path = os.path.join(tempdir, 'pickle-disabled.npy') + assert_raises(ValueError, np.save, path, np.array([None], dtype=object), + allow_pickle=False) + + def test_version_2_0(): f = BytesIO() # requires more than 2 byte for header @@ -694,6 +710,26 @@ malformed_magic = asbytes_nested([ '', ]) +def test_read_magic(): + s1 = BytesIO() + s2 = BytesIO() + + arr = np.ones((3, 6), dtype=float) + + format.write_array(s1, arr, version=(1, 0)) + format.write_array(s2, arr, version=(2, 0)) + + s1.seek(0) + s2.seek(0) + + version1 = format.read_magic(s1) + version2 = format.read_magic(s2) + + assert_(version1 == (1, 0)) + assert_(version2 == (2, 0)) + + assert_(s1.tell() == format.MAGIC_LEN) + assert_(s2.tell() == format.MAGIC_LEN) def test_read_magic_bad_magic(): for magic in malformed_magic: @@ -724,6 +760,30 @@ def test_large_header(): assert_raises(ValueError, format.write_array_header_1_0, s, d) +def test_read_array_header_1_0(): + s = BytesIO() + + arr = np.ones((3, 6), dtype=float) + format.write_array(s, arr, version=(1, 0)) + + s.seek(format.MAGIC_LEN) + shape, fortran, dtype = format.read_array_header_1_0(s) + + assert_((shape, fortran, dtype) == ((3, 6), False, float)) + + +def test_read_array_header_2_0(): + s = BytesIO() + + arr = np.ones((3, 6), dtype=float) + format.write_array(s, arr, version=(2, 0)) + + s.seek(format.MAGIC_LEN) + shape, fortran, dtype = format.read_array_header_2_0(s) + + assert_((shape, fortran, dtype) == ((3, 6), False, float)) + + def test_bad_header(): # header of length less than 2 should fail s = BytesIO() diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 3931f95e5..7afd1206c 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -277,6 +277,8 @@ class TestNanToNum(TestCase): def test_integer(self): vals = nan_to_num(1) assert_all(vals == 1) + vals = nan_to_num([1]) + assert_array_equal(vals, np.array([1], np.int)) def test_complex_good(self): vals = nan_to_num(1+1j) diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index a45d0bd86..99677b394 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -324,12 +324,13 @@ def nan_to_num(x): Returns ------- - out : ndarray, float - Array with the same shape as `x` and dtype of the element in `x` with - the greatest precision. NaN is replaced by zero, and infinity - (-infinity) is replaced by the largest (smallest or most negative) - floating point value that fits in the output dtype. All finite numbers - are upcast to the output dtype (default float64). + out : ndarray + New Array with the same shape as `x` and dtype of the element in + `x` with the greatest precision. If `x` is inexact, then NaN is + replaced by zero, and infinity (-infinity) is replaced by the + largest (smallest or most negative) floating point value that fits + in the output dtype. If `x` is not inexact, then a copy of `x` is + returned. See Also -------- @@ -354,33 +355,22 @@ def nan_to_num(x): -1.28000000e+002, 1.28000000e+002]) """ - try: - t = x.dtype.type - except AttributeError: - t = obj2sctype(type(x)) - if issubclass(t, _nx.complexfloating): - return nan_to_num(x.real) + 1j * nan_to_num(x.imag) - else: - try: - y = x.copy() - except AttributeError: - y = array(x) - if not issubclass(t, _nx.integer): - if not y.shape: - y = array([x]) - scalar = True - else: - scalar = False - are_inf = isposinf(y) - are_neg_inf = isneginf(y) - are_nan = isnan(y) - maxf, minf = _getmaxmin(y.dtype.type) - y[are_nan] = 0 - y[are_inf] = maxf - y[are_neg_inf] = minf - if scalar: - y = y[0] - return y + x = _nx.array(x, subok=True) + xtype = x.dtype.type + if not issubclass(xtype, _nx.inexact): + return x + + iscomplex = issubclass(xtype, _nx.complexfloating) + isscalar = (x.ndim == 0) + + x = x[None] if isscalar else x + dest = (x.real, x.imag) if iscomplex else (x,) + maxf, minf = _getmaxmin(x.real.dtype) + for d in dest: + _nx.copyto(d, 0.0, where=isnan(d)) + _nx.copyto(d, maxf, where=isposinf(d)) + _nx.copyto(d, minf, where=isneginf(d)) + return x[0] if isscalar else x #----------------------------------------------------------------------------- |