summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-01-21 08:39:49 +0200
committerGitHub <noreply@github.com>2019-01-21 08:39:49 +0200
commit3cbc11ac56054ad3ac7461e57433aefe37f2e3e4 (patch)
tree59d0b11caf42d5a55538ba69a00c9171cb1064b1 /numpy/core/src
parent74f3d07ab0bcfd63f42f62b26eeb6ce68efd4f21 (diff)
parent4b81a240a4ffffea8a502afbdea43d8bf2991228 (diff)
downloadnumpy-3cbc11ac56054ad3ac7461e57433aefe37f2e3e4.tar.gz
Merge pull request #12683 from tylerjereddy/timedelta64_divmod
ENH: add mm->qm divmod
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/umath/loops.c.src41
-rw-r--r--numpy/core/src/umath/loops.h.src3
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c49
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.h7
4 files changed, 95 insertions, 5 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 6accf3065..5267be261 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1623,7 +1623,7 @@ TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, voi
else {
if (in2 == 0) {
npy_set_floatstatus_divbyzero();
- *((npy_timedelta *)op1) = 0;
+ *((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
/* handle mixed case the way Python does */
@@ -1647,18 +1647,49 @@ TIMEDELTA_mm_q_floor_divide(char **args, npy_intp *dimensions, npy_intp *steps,
const npy_timedelta in2 = *(npy_timedelta *)ip2;
if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
npy_set_floatstatus_invalid();
- *((npy_timedelta *)op1) = 0;
+ *((npy_int64 *)op1) = 0;
}
else if (in2 == 0) {
npy_set_floatstatus_divbyzero();
- *((npy_timedelta *)op1) = 0;
+ *((npy_int64 *)op1) = 0;
}
else {
if (((in1 > 0) != (in2 > 0)) && (in1 % in2 != 0)) {
- *((npy_timedelta *)op1) = in1/in2 - 1;
+ *((npy_int64 *)op1) = in1/in2 - 1;
+ }
+ else {
+ *((npy_int64 *)op1) = in1/in2;
+ }
+ }
+ }
+}
+
+NPY_NO_EXPORT void
+TIMEDELTA_mm_qm_divmod(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
+{
+ BINARY_LOOP_TWO_OUT {
+ const npy_timedelta in1 = *(npy_timedelta *)ip1;
+ const npy_timedelta in2 = *(npy_timedelta *)ip2;
+ if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
+ npy_set_floatstatus_invalid();
+ *((npy_int64 *)op1) = 0;
+ *((npy_timedelta *)op2) = NPY_DATETIME_NAT;
+ }
+ else if (in2 == 0) {
+ npy_set_floatstatus_divbyzero();
+ *((npy_int64 *)op1) = 0;
+ *((npy_timedelta *)op2) = NPY_DATETIME_NAT;
+ }
+ else {
+ const npy_int64 quo = in1 / in2;
+ const npy_timedelta rem = in1 % in2;
+ if ((in1 > 0) == (in2 > 0) || rem == 0) {
+ *((npy_int64 *)op1) = quo;
+ *((npy_timedelta *)op2) = rem;
}
else {
- *((npy_timedelta *)op1) = in1/in2;
+ *((npy_int64 *)op1) = quo - 1;
+ *((npy_timedelta *)op2) = rem + in2;
}
}
}
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index 3c908121e..5264a6533 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -479,6 +479,9 @@ TIMEDELTA_mm_q_floor_divide(char **args, npy_intp *dimensions, npy_intp *steps,
NPY_NO_EXPORT void
TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+NPY_NO_EXPORT void
+TIMEDELTA_mm_qm_divmod(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+
/* Special case equivalents to above functions */
#define TIMEDELTA_mq_m_true_divide TIMEDELTA_mq_m_divide
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index e2f4d8005..a4a59faa9 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -2357,3 +2357,52 @@ type_tuple_type_resolver(PyUFuncObject *self,
return -1;
}
+
+NPY_NO_EXPORT int
+PyUFunc_DivmodTypeResolver(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes)
+{
+ int type_num1, type_num2;
+ int i;
+
+ type_num1 = PyArray_DESCR(operands[0])->type_num;
+ type_num2 = PyArray_DESCR(operands[1])->type_num;
+
+ /* Use the default when datetime and timedelta are not involved */
+ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
+ return PyUFunc_DefaultTypeResolver(ufunc, casting, operands,
+ type_tup, out_dtypes);
+ }
+ if (type_num1 == NPY_TIMEDELTA) {
+ if (type_num2 == NPY_TIMEDELTA) {
+ out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
+ PyArray_DESCR(operands[1]));
+ out_dtypes[1] = out_dtypes[0];
+ Py_INCREF(out_dtypes[1]);
+ out_dtypes[2] = PyArray_DescrFromType(NPY_LONGLONG);
+ Py_INCREF(out_dtypes[2]);
+ out_dtypes[3] = out_dtypes[0];
+ Py_INCREF(out_dtypes[3]);
+ }
+ else {
+ return raise_binary_type_reso_error(ufunc, operands);
+ }
+ }
+ else {
+ return raise_binary_type_reso_error(ufunc, operands);
+ }
+
+ /* Check against the casting rules */
+ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
+ for (i = 0; i < 4; ++i) {
+ Py_DECREF(out_dtypes[i]);
+ out_dtypes[i] = NULL;
+ }
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/numpy/core/src/umath/ufunc_type_resolution.h b/numpy/core/src/umath/ufunc_type_resolution.h
index 2f37af753..78313b1ef 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.h
+++ b/numpy/core/src/umath/ufunc_type_resolution.h
@@ -99,6 +99,13 @@ PyUFunc_RemainderTypeResolver(PyUFuncObject *ufunc,
PyObject *type_tup,
PyArray_Descr **out_dtypes);
+NPY_NO_EXPORT int
+PyUFunc_DivmodTypeResolver(PyUFuncObject *ufunc,
+ NPY_CASTING casting,
+ PyArrayObject **operands,
+ PyObject *type_tup,
+ PyArray_Descr **out_dtypes);
+
/*
* Does a linear search for the best inner loop of the ufunc.
*