diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2019-04-22 19:04:52 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-04-22 19:04:52 -0600 |
| commit | 6473fa20407ccd7581ac90cb3ef5b921a4a75cd7 (patch) | |
| tree | 57a44c55676a82ee9d689e047751c40b009bde89 /numpy/core | |
| parent | 2b59dcb273f00e7be13cdc32c5f396a55781c2f4 (diff) | |
| parent | 55c7ed2c6822b5a5b30db7472c863dc1fa0c338f (diff) | |
| download | numpy-6473fa20407ccd7581ac90cb3ef5b921a4a75cd7.tar.gz | |
Merge pull request #13371 from eric-wieser/__floor__-and-__ceil__
BUG/ENH: Make floor, ceil, and trunc call the matching special methods
Diffstat (limited to 'numpy/core')
| -rw-r--r-- | numpy/core/code_generators/generate_umath.py | 6 | ||||
| -rw-r--r-- | numpy/core/src/umath/funcs.inc.src | 33 | ||||
| -rw-r--r-- | numpy/core/tests/test_umath.py | 37 |
3 files changed, 73 insertions, 3 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 27a2ba1e1..e17523451 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -765,14 +765,14 @@ defdict = { docstrings.get('numpy.core.umath.ceil'), None, TD(flts, f='ceil', astype={'e':'f'}), - TD(P, f='ceil'), + TD(O, f='npy_ObjectCeil'), ), 'trunc': Ufunc(1, 1, None, docstrings.get('numpy.core.umath.trunc'), None, TD(flts, f='trunc', astype={'e':'f'}), - TD(P, f='trunc'), + TD(O, f='npy_ObjectTrunc'), ), 'fabs': Ufunc(1, 1, None, @@ -786,7 +786,7 @@ defdict = { docstrings.get('numpy.core.umath.floor'), None, TD(flts, f='floor', astype={'e':'f'}), - TD(P, f='floor'), + TD(O, f='npy_ObjectFloor'), ), 'rint': Ufunc(1, 1, None, diff --git a/numpy/core/src/umath/funcs.inc.src b/numpy/core/src/umath/funcs.inc.src index da2ab07f8..2acae3c37 100644 --- a/numpy/core/src/umath/funcs.inc.src +++ b/numpy/core/src/umath/funcs.inc.src @@ -160,6 +160,39 @@ npy_ObjectLogicalNot(PyObject *i1) } static PyObject * +npy_ObjectFloor(PyObject *obj) { + PyObject *math_floor_func = NULL; + + npy_cache_import("math", "floor", &math_floor_func); + if (math_floor_func == NULL) { + return NULL; + } + return PyObject_CallFunction(math_floor_func, "O", obj); +} + +static PyObject * +npy_ObjectCeil(PyObject *obj) { + PyObject *math_ceil_func = NULL; + + npy_cache_import("math", "ceil", &math_ceil_func); + if (math_ceil_func == NULL) { + return NULL; + } + return PyObject_CallFunction(math_ceil_func, "O", obj); +} + +static PyObject * +npy_ObjectTrunc(PyObject *obj) { + PyObject *math_trunc_func = NULL; + + npy_cache_import("math", "trunc", &math_trunc_func); + if (math_trunc_func == NULL) { + return NULL; + } + return PyObject_CallFunction(math_trunc_func, "O", obj); +} + +static PyObject * npy_ObjectGCD(PyObject *i1, PyObject *i2) { PyObject *gcd = NULL; diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 0eedd1e97..e0b0e11cf 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -5,6 +5,7 @@ import warnings import fnmatch import itertools import pytest +from fractions import Fraction import numpy.core.umath as ncu from numpy.core import _umath_tests as ncu_tests @@ -2460,6 +2461,42 @@ class TestRationalFunctions(object): assert_equal(np.gcd(2**100, 3**100), 1) +class TestRoundingFunctions(object): + + def test_object_direct(self): + """ test direct implementation of these magic methods """ + class C: + def __floor__(self): + return 1 + def __ceil__(self): + return 2 + def __trunc__(self): + return 3 + + arr = np.array([C(), C()]) + assert_equal(np.floor(arr), [1, 1]) + assert_equal(np.ceil(arr), [2, 2]) + assert_equal(np.trunc(arr), [3, 3]) + + def test_object_indirect(self): + """ test implementations via __float__ """ + class C: + def __float__(self): + return -2.5 + + arr = np.array([C(), C()]) + assert_equal(np.floor(arr), [-3, -3]) + assert_equal(np.ceil(arr), [-2, -2]) + with pytest.raises(TypeError): + np.trunc(arr) # consistent with math.trunc + + def test_fraction(self): + f = Fraction(-4, 3) + assert_equal(np.floor(f), -2) + assert_equal(np.ceil(f), -1) + assert_equal(np.trunc(f), -1) + + class TestComplexFunctions(object): funcs = [np.arcsin, np.arccos, np.arctan, np.arcsinh, np.arccosh, np.arctanh, np.sin, np.cos, np.tan, np.exp, |
