diff options
author | Ben Root <ben.v.root@gmail.com> | 2011-09-15 19:37:48 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-10-23 10:31:58 -0600 |
commit | 6fc0737d623f3065eb6fe720ce13cf1ef07cfdfe (patch) | |
tree | 3188fa174e417c1e882afee4365c497af38a7cde /numpy | |
parent | 31c29026bdde0735aceeebeb2e050f0c52fb1146 (diff) | |
download | numpy-6fc0737d623f3065eb6fe720ce13cf1ef07cfdfe.tar.gz |
ENH: Explicitly coded argmin for timedelta
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/include/numpy/ndarraytypes.h | 7 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 136 | ||||
-rw-r--r-- | numpy/core/src/multiarray/calculation.c | 111 | ||||
-rw-r--r-- | numpy/core/src/multiarray/usertypes.c | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 105 |
5 files changed, 341 insertions, 19 deletions
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index b4046f940..352d21e20 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -524,6 +524,13 @@ typedef struct { PyArray_FastClipFunc *fastclip; PyArray_FastPutmaskFunc *fastputmask; PyArray_FastTakeFunc *fasttake; + + /* + * Function to select smallest + * Can be NULL + */ + PyArray_ArgFunc *argmin; + } PyArray_ArrFuncs; /* The item must be reference counted when it is inserted or extracted. */ diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index fd46d929a..3225965ef 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2691,6 +2691,79 @@ static int /**end repeat**/ +/**begin repeat + * + * #fname = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA# + * #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong, + * longlong, ulonglong, npy_half, float, double, longdouble, + * float, double, longdouble, datetime, timedelta# + * #isfloat = 0*11, 1*7, 0*2# + * #isnan = nop*11, npy_half_isnan, npy_isnan*6, nop*2# + * #le = _LESS_THAN_OR_EQUAL*11, npy_half_le, _LESS_THAN_OR_EQUAL*8# + * #iscomplex = 0*15, 1*3, 0*2# + * #incr = ip++*15, ip+=2*3, ip++*2# + */ +static int +@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip)) +{ + intp i; + @type@ mp = *ip; +#if @iscomplex@ + @type@ mp_im = ip[1]; +#endif + + *min_ind = 0; + +#if @isfloat@ + if (@isnan@(mp)) { + /* nan encountered; it's minimal */ + return 0; + } +#endif +#if @iscomplex@ + if (@isnan@(mp_im)) { + /* nan encountered; it's minimal */ + return 0; + } +#endif + + for (i = 1; i < n; i++) { + @incr@; + /* + * Propagate nans, similarly as max() and min() + */ +#if @iscomplex@ + /* Lexical order for complex numbers */ + if ((mp > ip[0]) || ((ip[0] == mp) && (mp_im > ip[1])) + || @isnan@(ip[0]) || @isnan@(ip[1])) { + mp = ip[0]; + mp_im = ip[1]; + *min_ind = i; + if (@isnan@(mp) || @isnan@(mp_im)) { + /* nan encountered, it's minimal */ + break; + } + } +#else + if (!@le@(mp, *ip)) { /* negated, for correct nan handling */ + mp = *ip; + *min_ind = i; +#if @isfloat@ + if (@isnan@(mp)) { + /* nan encountered, it's minimal */ + break; + } +#endif + } +#endif + } + return 0; +} + +/**end repeat**/ + #undef _LESS_THAN_OR_EQUAL static int @@ -2749,6 +2822,63 @@ static int #define VOID_argmax NULL +static int +OBJECT_argmin(PyObject **ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip)) +{ + intp i; + PyObject *mp = ip[0]; + + *min_ind = 0; + i = 1; + while (i < n && mp == NULL) { + mp = ip[i]; + i++; + } + for (; i < n; i++) { + ip++; +#if defined(NPY_PY3K) + if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) { +#else + if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) { +#endif + mp = *ip; + *min_ind = i; + } + } + return 0; +} + +/**begin repeat + * + * #fname = STRING, UNICODE# + * #type = char, PyArray_UCS4# + */ +static int +@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *aip) +{ + intp i; + int elsize = PyArray_DESCR(aip)->elsize; + @type@ *mp = (@type@ *)PyArray_malloc(elsize); + + if (mp==NULL) return 0; + memcpy(mp, ip, elsize); + *min_ind = 0; + for(i=1; i<n; i++) { + ip += elsize; + if (@fname@_compare(mp,ip,aip) > 0) { + memcpy(mp, ip, elsize); + *min_ind=i; + } + } + PyArray_free(mp); + return 0; +} + +/**end repeat**/ + + +#define VOID_argmin NULL + /* ***************************************************************************** @@ -3410,7 +3540,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = { NULL, (PyArray_FastClipFunc *)NULL, (PyArray_FastPutmaskFunc *)NULL, - (PyArray_FastTakeFunc *)NULL + (PyArray_FastTakeFunc *)NULL, + (PyArray_ArgFunc*)@from@_argmin }; /* @@ -3518,7 +3649,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = { NULL, (PyArray_FastClipFunc*)@from@_fastclip, (PyArray_FastPutmaskFunc*)@from@_fastputmask, - (PyArray_FastTakeFunc*)@from@_fasttake + (PyArray_FastTakeFunc*)@from@_fasttake, + (PyArray_ArgFunc*)@from@_argmin }; /* diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c index 46108ac03..4f2db096c 100644 --- a/numpy/core/src/multiarray/calculation.c +++ b/numpy/core/src/multiarray/calculation.c @@ -146,32 +146,109 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out) * ArgMin */ NPY_NO_EXPORT PyObject * -PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out) +PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out) { - PyObject *obj, *new, *ret; + PyArrayObject *ap = NULL, *rp = NULL; + PyArray_ArgFunc* arg_func; + char *ip; + intp *rptr; + intp i, n, m; + int elsize; + NPY_BEGIN_THREADS_DEF; - if (PyArray_ISFLEXIBLE(ap)) { - PyErr_SetString(PyExc_TypeError, - "argmin is unsupported for this type"); + if ((ap=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) { return NULL; } - else if (PyArray_ISUNSIGNED(ap)) { - obj = PyInt_FromLong((long) -1); - } - else if (PyArray_TYPE(ap) == PyArray_BOOL) { - obj = PyInt_FromLong((long) 1); + /* + * We need to permute the array so that axis is placed at the end. + * And all other dimensions are shifted left. + */ + if (axis != PyArray_NDIM(ap)-1) { + PyArray_Dims newaxes; + intp dims[MAX_DIMS]; + int i; + + newaxes.ptr = dims; + newaxes.len = PyArray_NDIM(ap); + for (i = 0; i < axis; i++) dims[i] = i; + for (i = axis; i < PyArray_NDIM(ap) - 1; i++) dims[i] = i + 1; + dims[PyArray_NDIM(ap) - 1] = axis; + op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes); + Py_DECREF(ap); + if (op == NULL) { + return NULL; + } } else { - obj = PyInt_FromLong((long) 0); + op = ap; } - new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap)); - Py_DECREF(obj); - if (new == NULL) { + + /* Will get native-byte order contiguous copy. */ + ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op, + PyArray_DESCR(op)->type_num, 1, 0); + Py_DECREF(op); + if (ap == NULL) { return NULL; } - ret = PyArray_ArgMax((PyArrayObject *)new, axis, out); - Py_DECREF(new); - return ret; + arg_func = PyArray_DESCR(ap)->f->argmin; + if (arg_func == NULL) { + PyErr_SetString(PyExc_TypeError, "data type not ordered"); + goto fail; + } + elsize = PyArray_DESCR(ap)->elsize; + m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1]; + if (m == 0) { + PyErr_SetString(PyExc_ValueError, + "attempt to get argmax/argmin "\ + "of an empty sequence"); + goto fail; + } + + if (!out) { + rp = (PyArrayObject *)PyArray_New(Py_TYPE(ap), PyArray_NDIM(ap)-1, + PyArray_DIMS(ap), PyArray_INTP, + NULL, NULL, 0, 0, + (PyObject *)ap); + if (rp == NULL) { + goto fail; + } + } + else { + if (PyArray_SIZE(out) != + PyArray_MultiplyList(PyArray_DIMS(ap), PyArray_NDIM(ap) - 1)) { + PyErr_SetString(PyExc_TypeError, + "invalid shape for output array."); + } + rp = (PyArrayObject *)PyArray_FromArray(out, + PyArray_DescrFromType(PyArray_INTP), + NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY); + if (rp == NULL) { + goto fail; + } + } + + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap)); + n = PyArray_SIZE(ap)/m; + rptr = (intp *)PyArray_DATA(rp); + for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) { + arg_func(ip, m, rptr, ap); + rptr += 1; + } + NPY_END_THREADS_DESCR(PyArray_DESCR(ap)); + + Py_DECREF(ap); + /* Trigger the UPDATEIFCOPY if necessary */ + if (out != NULL && out != rp) { + Py_DECREF(rp); + rp = out; + Py_INCREF(rp); + } + return (PyObject *)rp; + + fail: + Py_DECREF(ap); + Py_XDECREF(rp); + return NULL; } /*NUMPY_API diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index 53acff42d..61df37b16 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -104,6 +104,7 @@ PyArray_InitArrFuncs(PyArray_ArrFuncs *f) f->copyswap = NULL; f->compare = NULL; f->argmax = NULL; + f->argmin = NULL; f->dotfunc = NULL; f->scanfunc = NULL; f->fromstr = NULL; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index e478de7e0..88d52568c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -7,6 +7,9 @@ from nose import SkipTest from numpy.core import * from numpy.core.multiarray_tests import test_neighborhood_iterator, test_neighborhood_iterator_oob +# Need to test an object that does not fully implement math interface +from datetime import timedelta + from numpy.compat import asbytes, getexception, strchar from test_print import in_foreign_locale @@ -925,6 +928,38 @@ class TestArgmax(TestCase): ([complex(0, 0), complex(0, 2), complex(0, 1)], 1), ([complex(1, 0), complex(0, 2), complex(0, 1)], 0), ([complex(1, 0), complex(0, 2), complex(1, 1)], 2), + + # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug + #([np.datetime64('1923-04-14T12:43:12'), + # np.datetime64('1994-06-21T14:43:15'), + # np.datetime64('2001-10-15T04:10:32'), + # np.datetime64('1995-11-25T16:02:16'), + # np.datetime64('2005-01-04T03:14:12'), + # np.datetime64('2041-12-03T14:05:03')], 5), + ([np.datetime64('1935-09-14T04:40:11'), + np.datetime64('1949-10-12T12:32:11'), + np.datetime64('2010-01-03T05:14:12'), + np.datetime64('2015-11-20T12:20:59'), + np.datetime64('1932-09-23T10:10:13'), + np.datetime64('2014-10-10T03:50:30')], 3), + #([np.datetime64('2059-03-14T12:43:12'), + # np.datetime64('1996-09-21T14:43:15'), + # np.datetime64('2001-10-15T04:10:32'), + # np.datetime64('2022-12-25T16:02:16'), + # np.datetime64('1963-10-04T03:14:12'), + # np.datetime64('2013-05-08T18:15:23')], 0), + + ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35), + timedelta(days=-1, seconds=23)], 0), + ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5), + timedelta(days=5, seconds=14)], 1), + ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5), + timedelta(days=10, seconds=43)], 2), + + # Can't reduce a "flexible type" + #(['a', 'z', 'aa', 'zz'], 3), + #(['zz', 'a', 'aa', 'a'], 0), + #(['aa', 'z', 'zz', 'a'], 2), ] def test_all(self): @@ -942,6 +977,76 @@ class TestArgmax(TestCase): assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr) +class TestArgmin(TestCase): + + nan_arr = [ + ([0, 1, 2, 3, np.nan], 4), + ([0, 1, 2, np.nan, 3], 3), + ([np.nan, 0, 1, 2, 3], 0), + ([np.nan, 0, np.nan, 2, 3], 0), + ([0, 1, 2, 3, complex(0,np.nan)], 4), + ([0, 1, 2, 3, complex(np.nan,0)], 4), + ([0, 1, 2, complex(np.nan,0), 3], 3), + ([0, 1, 2, complex(0,np.nan), 3], 3), + ([complex(0,np.nan), 0, 1, 2, 3], 0), + ([complex(np.nan, np.nan), 0, 1, 2, 3], 0), + ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0), + ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0), + ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0), + + ([complex(0, 0), complex(0, 2), complex(0, 1)], 0), + ([complex(1, 0), complex(0, 2), complex(0, 1)], 2), + ([complex(1, 0), complex(0, 2), complex(1, 1)], 1), + + # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug + #([np.datetime64('1923-04-14T12:43:12'), + # np.datetime64('1994-06-21T14:43:15'), + # np.datetime64('2001-10-15T04:10:32'), + # np.datetime64('1995-11-25T16:02:16'), + # np.datetime64('2005-01-04T03:14:12'), + # np.datetime64('2041-12-03T14:05:03')], 0), + ([np.datetime64('1935-09-14T04:40:11'), + np.datetime64('1949-10-12T12:32:11'), + np.datetime64('2010-01-03T05:14:12'), + np.datetime64('2014-11-20T12:20:59'), + np.datetime64('2015-09-23T10:10:13'), + np.datetime64('1932-10-10T03:50:30')], 5), + #([np.datetime64('2059-03-14T12:43:12'), + # np.datetime64('1996-09-21T14:43:15'), + # np.datetime64('2001-10-15T04:10:32'), + # np.datetime64('2022-12-25T16:02:16'), + # np.datetime64('1963-10-04T03:14:12'), + # np.datetime64('2013-05-08T18:15:23')], 4), + + ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35), + timedelta(days=-1, seconds=23)], 2), + ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5), + timedelta(days=5, seconds=14)], 0), + ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5), + timedelta(days=10, seconds=43)], 1), + + # Can't reduce a "flexible type" + #(['a', 'z', 'aa', 'zz'], 0), + #(['zz', 'a', 'aa', 'a'], 1), + #(['aa', 'z', 'zz', 'a'], 3), + ] + + def test_all(self): + a = np.random.normal(0,1,(4,5,6,7,8)) + for i in xrange(a.ndim): + amin = a.min(i) + aargmin = a.argmin(i) + axes = range(a.ndim) + axes.remove(i) + assert_(all(amin == aargmin.choose(*a.transpose(i,*axes)))) + + def test_combinations(self): + for arr, pos in self.nan_arr: + assert_equal(np.argmin(arr), pos, err_msg="%r"%arr) + assert_equal(arr[np.argmin(arr)], np.min(arr), err_msg="%r"%arr) + + + class TestMinMax(TestCase): def test_scalar(self): assert_raises(ValueError, np.amax, 1, 1) |