summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBen Root <ben.v.root@gmail.com>2011-09-15 19:37:48 -0500
committerCharles Harris <charlesr.harris@gmail.com>2011-10-23 10:31:58 -0600
commit6fc0737d623f3065eb6fe720ce13cf1ef07cfdfe (patch)
tree3188fa174e417c1e882afee4365c497af38a7cde /numpy
parent31c29026bdde0735aceeebeb2e050f0c52fb1146 (diff)
downloadnumpy-6fc0737d623f3065eb6fe720ce13cf1ef07cfdfe.tar.gz
ENH: Explicitly coded argmin for timedelta
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/include/numpy/ndarraytypes.h7
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src136
-rw-r--r--numpy/core/src/multiarray/calculation.c111
-rw-r--r--numpy/core/src/multiarray/usertypes.c1
-rw-r--r--numpy/core/tests/test_multiarray.py105
5 files changed, 341 insertions, 19 deletions
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index b4046f940..352d21e20 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -524,6 +524,13 @@ typedef struct {
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
PyArray_FastTakeFunc *fasttake;
+
+ /*
+ * Function to select smallest
+ * Can be NULL
+ */
+ PyArray_ArgFunc *argmin;
+
} PyArray_ArrFuncs;
/* The item must be reference counted when it is inserted or extracted. */
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index fd46d929a..3225965ef 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -2691,6 +2691,79 @@ static int
/**end repeat**/
+/**begin repeat
+ *
+ * #fname = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
+ * #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
+ * longlong, ulonglong, npy_half, float, double, longdouble,
+ * float, double, longdouble, datetime, timedelta#
+ * #isfloat = 0*11, 1*7, 0*2#
+ * #isnan = nop*11, npy_half_isnan, npy_isnan*6, nop*2#
+ * #le = _LESS_THAN_OR_EQUAL*11, npy_half_le, _LESS_THAN_OR_EQUAL*8#
+ * #iscomplex = 0*15, 1*3, 0*2#
+ * #incr = ip++*15, ip+=2*3, ip++*2#
+ */
+static int
+@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
+{
+ intp i;
+ @type@ mp = *ip;
+#if @iscomplex@
+ @type@ mp_im = ip[1];
+#endif
+
+ *min_ind = 0;
+
+#if @isfloat@
+ if (@isnan@(mp)) {
+ /* nan encountered; it's minimal */
+ return 0;
+ }
+#endif
+#if @iscomplex@
+ if (@isnan@(mp_im)) {
+ /* nan encountered; it's minimal */
+ return 0;
+ }
+#endif
+
+ for (i = 1; i < n; i++) {
+ @incr@;
+ /*
+ * Propagate nans, similarly as max() and min()
+ */
+#if @iscomplex@
+ /* Lexical order for complex numbers */
+ if ((mp > ip[0]) || ((ip[0] == mp) && (mp_im > ip[1]))
+ || @isnan@(ip[0]) || @isnan@(ip[1])) {
+ mp = ip[0];
+ mp_im = ip[1];
+ *min_ind = i;
+ if (@isnan@(mp) || @isnan@(mp_im)) {
+ /* nan encountered, it's minimal */
+ break;
+ }
+ }
+#else
+ if (!@le@(mp, *ip)) { /* negated, for correct nan handling */
+ mp = *ip;
+ *min_ind = i;
+#if @isfloat@
+ if (@isnan@(mp)) {
+ /* nan encountered, it's minimal */
+ break;
+ }
+#endif
+ }
+#endif
+ }
+ return 0;
+}
+
+/**end repeat**/
+
#undef _LESS_THAN_OR_EQUAL
static int
@@ -2749,6 +2822,63 @@ static int
#define VOID_argmax NULL
+static int
+OBJECT_argmin(PyObject **ip, intp n, intp *min_ind, PyArrayObject *NPY_UNUSED(aip))
+{
+ intp i;
+ PyObject *mp = ip[0];
+
+ *min_ind = 0;
+ i = 1;
+ while (i < n && mp == NULL) {
+ mp = ip[i];
+ i++;
+ }
+ for (; i < n; i++) {
+ ip++;
+#if defined(NPY_PY3K)
+ if (*ip != NULL && PyObject_RichCompareBool(mp, *ip, Py_GT) == 1) {
+#else
+ if (*ip != NULL && PyObject_Compare(mp, *ip) > 0) {
+#endif
+ mp = *ip;
+ *min_ind = i;
+ }
+ }
+ return 0;
+}
+
+/**begin repeat
+ *
+ * #fname = STRING, UNICODE#
+ * #type = char, PyArray_UCS4#
+ */
+static int
+@fname@_argmin(@type@ *ip, intp n, intp *min_ind, PyArrayObject *aip)
+{
+ intp i;
+ int elsize = PyArray_DESCR(aip)->elsize;
+ @type@ *mp = (@type@ *)PyArray_malloc(elsize);
+
+ if (mp==NULL) return 0;
+ memcpy(mp, ip, elsize);
+ *min_ind = 0;
+ for(i=1; i<n; i++) {
+ ip += elsize;
+ if (@fname@_compare(mp,ip,aip) > 0) {
+ memcpy(mp, ip, elsize);
+ *min_ind=i;
+ }
+ }
+ PyArray_free(mp);
+ return 0;
+}
+
+/**end repeat**/
+
+
+#define VOID_argmin NULL
+
/*
*****************************************************************************
@@ -3410,7 +3540,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc *)NULL,
(PyArray_FastPutmaskFunc *)NULL,
- (PyArray_FastTakeFunc *)NULL
+ (PyArray_FastTakeFunc *)NULL,
+ (PyArray_ArgFunc*)@from@_argmin
};
/*
@@ -3518,7 +3649,8 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
NULL,
(PyArray_FastClipFunc*)@from@_fastclip,
(PyArray_FastPutmaskFunc*)@from@_fastputmask,
- (PyArray_FastTakeFunc*)@from@_fasttake
+ (PyArray_FastTakeFunc*)@from@_fasttake,
+ (PyArray_ArgFunc*)@from@_argmin
};
/*
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c
index 46108ac03..4f2db096c 100644
--- a/numpy/core/src/multiarray/calculation.c
+++ b/numpy/core/src/multiarray/calculation.c
@@ -146,32 +146,109 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
* ArgMin
*/
NPY_NO_EXPORT PyObject *
-PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out)
+PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
{
- PyObject *obj, *new, *ret;
+ PyArrayObject *ap = NULL, *rp = NULL;
+ PyArray_ArgFunc* arg_func;
+ char *ip;
+ intp *rptr;
+ intp i, n, m;
+ int elsize;
+ NPY_BEGIN_THREADS_DEF;
- if (PyArray_ISFLEXIBLE(ap)) {
- PyErr_SetString(PyExc_TypeError,
- "argmin is unsupported for this type");
+ if ((ap=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
return NULL;
}
- else if (PyArray_ISUNSIGNED(ap)) {
- obj = PyInt_FromLong((long) -1);
- }
- else if (PyArray_TYPE(ap) == PyArray_BOOL) {
- obj = PyInt_FromLong((long) 1);
+ /*
+ * We need to permute the array so that axis is placed at the end.
+ * And all other dimensions are shifted left.
+ */
+ if (axis != PyArray_NDIM(ap)-1) {
+ PyArray_Dims newaxes;
+ intp dims[MAX_DIMS];
+ int i;
+
+ newaxes.ptr = dims;
+ newaxes.len = PyArray_NDIM(ap);
+ for (i = 0; i < axis; i++) dims[i] = i;
+ for (i = axis; i < PyArray_NDIM(ap) - 1; i++) dims[i] = i + 1;
+ dims[PyArray_NDIM(ap) - 1] = axis;
+ op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
+ Py_DECREF(ap);
+ if (op == NULL) {
+ return NULL;
+ }
}
else {
- obj = PyInt_FromLong((long) 0);
+ op = ap;
}
- new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap));
- Py_DECREF(obj);
- if (new == NULL) {
+
+ /* Will get native-byte order contiguous copy. */
+ ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
+ PyArray_DESCR(op)->type_num, 1, 0);
+ Py_DECREF(op);
+ if (ap == NULL) {
return NULL;
}
- ret = PyArray_ArgMax((PyArrayObject *)new, axis, out);
- Py_DECREF(new);
- return ret;
+ arg_func = PyArray_DESCR(ap)->f->argmin;
+ if (arg_func == NULL) {
+ PyErr_SetString(PyExc_TypeError, "data type not ordered");
+ goto fail;
+ }
+ elsize = PyArray_DESCR(ap)->elsize;
+ m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1];
+ if (m == 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "attempt to get argmax/argmin "\
+ "of an empty sequence");
+ goto fail;
+ }
+
+ if (!out) {
+ rp = (PyArrayObject *)PyArray_New(Py_TYPE(ap), PyArray_NDIM(ap)-1,
+ PyArray_DIMS(ap), PyArray_INTP,
+ NULL, NULL, 0, 0,
+ (PyObject *)ap);
+ if (rp == NULL) {
+ goto fail;
+ }
+ }
+ else {
+ if (PyArray_SIZE(out) !=
+ PyArray_MultiplyList(PyArray_DIMS(ap), PyArray_NDIM(ap) - 1)) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid shape for output array.");
+ }
+ rp = (PyArrayObject *)PyArray_FromArray(out,
+ PyArray_DescrFromType(PyArray_INTP),
+ NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY);
+ if (rp == NULL) {
+ goto fail;
+ }
+ }
+
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap));
+ n = PyArray_SIZE(ap)/m;
+ rptr = (intp *)PyArray_DATA(rp);
+ for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) {
+ arg_func(ip, m, rptr, ap);
+ rptr += 1;
+ }
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap));
+
+ Py_DECREF(ap);
+ /* Trigger the UPDATEIFCOPY if necessary */
+ if (out != NULL && out != rp) {
+ Py_DECREF(rp);
+ rp = out;
+ Py_INCREF(rp);
+ }
+ return (PyObject *)rp;
+
+ fail:
+ Py_DECREF(ap);
+ Py_XDECREF(rp);
+ return NULL;
}
/*NUMPY_API
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index 53acff42d..61df37b16 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -104,6 +104,7 @@ PyArray_InitArrFuncs(PyArray_ArrFuncs *f)
f->copyswap = NULL;
f->compare = NULL;
f->argmax = NULL;
+ f->argmin = NULL;
f->dotfunc = NULL;
f->scanfunc = NULL;
f->fromstr = NULL;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index e478de7e0..88d52568c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -7,6 +7,9 @@ from nose import SkipTest
from numpy.core import *
from numpy.core.multiarray_tests import test_neighborhood_iterator, test_neighborhood_iterator_oob
+# Need to test an object that does not fully implement math interface
+from datetime import timedelta
+
from numpy.compat import asbytes, getexception, strchar
from test_print import in_foreign_locale
@@ -925,6 +928,38 @@ class TestArgmax(TestCase):
([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
+
+ # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug
+ #([np.datetime64('1923-04-14T12:43:12'),
+ # np.datetime64('1994-06-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('1995-11-25T16:02:16'),
+ # np.datetime64('2005-01-04T03:14:12'),
+ # np.datetime64('2041-12-03T14:05:03')], 5),
+ ([np.datetime64('1935-09-14T04:40:11'),
+ np.datetime64('1949-10-12T12:32:11'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('2015-11-20T12:20:59'),
+ np.datetime64('1932-09-23T10:10:13'),
+ np.datetime64('2014-10-10T03:50:30')], 3),
+ #([np.datetime64('2059-03-14T12:43:12'),
+ # np.datetime64('1996-09-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('2022-12-25T16:02:16'),
+ # np.datetime64('1963-10-04T03:14:12'),
+ # np.datetime64('2013-05-08T18:15:23')], 0),
+
+ ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
+ timedelta(days=-1, seconds=23)], 0),
+ ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5),
+ timedelta(days=5, seconds=14)], 1),
+ ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5),
+ timedelta(days=10, seconds=43)], 2),
+
+ # Can't reduce a "flexible type"
+ #(['a', 'z', 'aa', 'zz'], 3),
+ #(['zz', 'a', 'aa', 'a'], 0),
+ #(['aa', 'z', 'zz', 'a'], 2),
]
def test_all(self):
@@ -942,6 +977,76 @@ class TestArgmax(TestCase):
assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr)
+class TestArgmin(TestCase):
+
+ nan_arr = [
+ ([0, 1, 2, 3, np.nan], 4),
+ ([0, 1, 2, np.nan, 3], 3),
+ ([np.nan, 0, 1, 2, 3], 0),
+ ([np.nan, 0, np.nan, 2, 3], 0),
+ ([0, 1, 2, 3, complex(0,np.nan)], 4),
+ ([0, 1, 2, 3, complex(np.nan,0)], 4),
+ ([0, 1, 2, complex(np.nan,0), 3], 3),
+ ([0, 1, 2, complex(0,np.nan), 3], 3),
+ ([complex(0,np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
+
+ ([complex(0, 0), complex(0, 2), complex(0, 1)], 0),
+ ([complex(1, 0), complex(0, 2), complex(0, 1)], 2),
+ ([complex(1, 0), complex(0, 2), complex(1, 1)], 1),
+
+ # Fails on 32-bit systems (haven't tested 64-bit) due to y2.038k bug
+ #([np.datetime64('1923-04-14T12:43:12'),
+ # np.datetime64('1994-06-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('1995-11-25T16:02:16'),
+ # np.datetime64('2005-01-04T03:14:12'),
+ # np.datetime64('2041-12-03T14:05:03')], 0),
+ ([np.datetime64('1935-09-14T04:40:11'),
+ np.datetime64('1949-10-12T12:32:11'),
+ np.datetime64('2010-01-03T05:14:12'),
+ np.datetime64('2014-11-20T12:20:59'),
+ np.datetime64('2015-09-23T10:10:13'),
+ np.datetime64('1932-10-10T03:50:30')], 5),
+ #([np.datetime64('2059-03-14T12:43:12'),
+ # np.datetime64('1996-09-21T14:43:15'),
+ # np.datetime64('2001-10-15T04:10:32'),
+ # np.datetime64('2022-12-25T16:02:16'),
+ # np.datetime64('1963-10-04T03:14:12'),
+ # np.datetime64('2013-05-08T18:15:23')], 4),
+
+ ([timedelta(days=5, seconds=14), timedelta(days=2, seconds=35),
+ timedelta(days=-1, seconds=23)], 2),
+ ([timedelta(days=1, seconds=43), timedelta(days=10, seconds=5),
+ timedelta(days=5, seconds=14)], 0),
+ ([timedelta(days=10, seconds=24), timedelta(days=10, seconds=5),
+ timedelta(days=10, seconds=43)], 1),
+
+ # Can't reduce a "flexible type"
+ #(['a', 'z', 'aa', 'zz'], 0),
+ #(['zz', 'a', 'aa', 'a'], 1),
+ #(['aa', 'z', 'zz', 'a'], 3),
+ ]
+
+ def test_all(self):
+ a = np.random.normal(0,1,(4,5,6,7,8))
+ for i in xrange(a.ndim):
+ amin = a.min(i)
+ aargmin = a.argmin(i)
+ axes = range(a.ndim)
+ axes.remove(i)
+ assert_(all(amin == aargmin.choose(*a.transpose(i,*axes))))
+
+ def test_combinations(self):
+ for arr, pos in self.nan_arr:
+ assert_equal(np.argmin(arr), pos, err_msg="%r"%arr)
+ assert_equal(arr[np.argmin(arr)], np.min(arr), err_msg="%r"%arr)
+
+
+
class TestMinMax(TestCase):
def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)