summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-04-22 19:04:52 -0600
committerGitHub <noreply@github.com>2019-04-22 19:04:52 -0600
commit6473fa20407ccd7581ac90cb3ef5b921a4a75cd7 (patch)
tree57a44c55676a82ee9d689e047751c40b009bde89 /numpy/core
parent2b59dcb273f00e7be13cdc32c5f396a55781c2f4 (diff)
parent55c7ed2c6822b5a5b30db7472c863dc1fa0c338f (diff)
downloadnumpy-6473fa20407ccd7581ac90cb3ef5b921a4a75cd7.tar.gz
Merge pull request #13371 from eric-wieser/__floor__-and-__ceil__
BUG/ENH: Make floor, ceil, and trunc call the matching special methods
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/code_generators/generate_umath.py6
-rw-r--r--numpy/core/src/umath/funcs.inc.src33
-rw-r--r--numpy/core/tests/test_umath.py37
3 files changed, 73 insertions, 3 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 27a2ba1e1..e17523451 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -765,14 +765,14 @@ defdict = {
docstrings.get('numpy.core.umath.ceil'),
None,
TD(flts, f='ceil', astype={'e':'f'}),
- TD(P, f='ceil'),
+ TD(O, f='npy_ObjectCeil'),
),
'trunc':
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.trunc'),
None,
TD(flts, f='trunc', astype={'e':'f'}),
- TD(P, f='trunc'),
+ TD(O, f='npy_ObjectTrunc'),
),
'fabs':
Ufunc(1, 1, None,
@@ -786,7 +786,7 @@ defdict = {
docstrings.get('numpy.core.umath.floor'),
None,
TD(flts, f='floor', astype={'e':'f'}),
- TD(P, f='floor'),
+ TD(O, f='npy_ObjectFloor'),
),
'rint':
Ufunc(1, 1, None,
diff --git a/numpy/core/src/umath/funcs.inc.src b/numpy/core/src/umath/funcs.inc.src
index da2ab07f8..2acae3c37 100644
--- a/numpy/core/src/umath/funcs.inc.src
+++ b/numpy/core/src/umath/funcs.inc.src
@@ -160,6 +160,39 @@ npy_ObjectLogicalNot(PyObject *i1)
}
static PyObject *
+npy_ObjectFloor(PyObject *obj) {
+ PyObject *math_floor_func = NULL;
+
+ npy_cache_import("math", "floor", &math_floor_func);
+ if (math_floor_func == NULL) {
+ return NULL;
+ }
+ return PyObject_CallFunction(math_floor_func, "O", obj);
+}
+
+static PyObject *
+npy_ObjectCeil(PyObject *obj) {
+ PyObject *math_ceil_func = NULL;
+
+ npy_cache_import("math", "ceil", &math_ceil_func);
+ if (math_ceil_func == NULL) {
+ return NULL;
+ }
+ return PyObject_CallFunction(math_ceil_func, "O", obj);
+}
+
+static PyObject *
+npy_ObjectTrunc(PyObject *obj) {
+ PyObject *math_trunc_func = NULL;
+
+ npy_cache_import("math", "trunc", &math_trunc_func);
+ if (math_trunc_func == NULL) {
+ return NULL;
+ }
+ return PyObject_CallFunction(math_trunc_func, "O", obj);
+}
+
+static PyObject *
npy_ObjectGCD(PyObject *i1, PyObject *i2)
{
PyObject *gcd = NULL;
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 0eedd1e97..e0b0e11cf 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -5,6 +5,7 @@ import warnings
import fnmatch
import itertools
import pytest
+from fractions import Fraction
import numpy.core.umath as ncu
from numpy.core import _umath_tests as ncu_tests
@@ -2460,6 +2461,42 @@ class TestRationalFunctions(object):
assert_equal(np.gcd(2**100, 3**100), 1)
+class TestRoundingFunctions(object):
+
+ def test_object_direct(self):
+ """ test direct implementation of these magic methods """
+ class C:
+ def __floor__(self):
+ return 1
+ def __ceil__(self):
+ return 2
+ def __trunc__(self):
+ return 3
+
+ arr = np.array([C(), C()])
+ assert_equal(np.floor(arr), [1, 1])
+ assert_equal(np.ceil(arr), [2, 2])
+ assert_equal(np.trunc(arr), [3, 3])
+
+ def test_object_indirect(self):
+ """ test implementations via __float__ """
+ class C:
+ def __float__(self):
+ return -2.5
+
+ arr = np.array([C(), C()])
+ assert_equal(np.floor(arr), [-3, -3])
+ assert_equal(np.ceil(arr), [-2, -2])
+ with pytest.raises(TypeError):
+ np.trunc(arr) # consistent with math.trunc
+
+ def test_fraction(self):
+ f = Fraction(-4, 3)
+ assert_equal(np.floor(f), -2)
+ assert_equal(np.ceil(f), -1)
+ assert_equal(np.trunc(f), -1)
+
+
class TestComplexFunctions(object):
funcs = [np.arcsin, np.arccos, np.arctan, np.arcsinh, np.arccosh,
np.arctanh, np.sin, np.cos, np.tan, np.exp,