diff options
author | Hameer Abbasi <einstein.edison@gmail.com> | 2020-04-10 20:19:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-10 13:19:36 -0500 |
commit | 1f3650643f341af8ca5c12529a479d73d3103895 (patch) | |
tree | 11ba7e500599d4aab078d8cb21b5c5f9e02ebe4e /numpy | |
parent | 42228fcb513c08729f822fa6be70e16a7a68067f (diff) | |
download | numpy-1f3650643f341af8ca5c12529a479d73d3103895.tar.gz |
BUG,DEP: Make `scalar.__round__()` behave like pythons round (#15840)
See issue gh-15297 and related mailing list discussion.
This PR bring scalar.__round__() in line with python, so that `round(scalar)` always returns a python integer, while `round(scalar, ndigits=0)` returns the same type. Since complex numbers are not supported in Python, and cannot be reasonably cast to integers they are deprecated.
Closes gh-15297
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 72 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 24 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 48 |
3 files changed, 137 insertions, 7 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index eafa13ff2..2f1767391 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1549,6 +1549,58 @@ gentype_@name@(PyObject *self, PyObject *args, PyObject *kwds) } /**end repeat**/ + +/**begin repeat + * #name = integer, floating, complexfloating# + * #complex = 0, 0, 1# + */ +static PyObject * +@name@type_dunder_round(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"ndigits", NULL}; + PyObject *ndigits = Py_None; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O:__round__", kwlist, &ndigits)) { + return NULL; + } + +#if @complex@ + if (DEPRECATE("The Python built-in `round` is deprecated for complex " + "scalars, and will raise a `TypeError` in a future release. " + "Use `np.round` or `scalar.round` instead.") < 0) { + return NULL; + } +#endif + + PyObject *tup; + if (ndigits == Py_None) { + tup = PyTuple_Pack(0); + } + else { + tup = PyTuple_Pack(1, ndigits); + } + + if (tup == NULL) { + return NULL; + } + + PyObject *obj = gentype_round(self, tup, NULL); + Py_DECREF(tup); + if (obj == NULL) { + return NULL; + } + +#if !@complex@ + if (ndigits == Py_None) { + PyObject *ret = PyNumber_Long(obj); + Py_DECREF(obj); + return ret; + } +#endif + + return obj; +} +/**end repeat**/ + static PyObject * voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) { @@ -2056,10 +2108,6 @@ static PyMethodDef gentype_methods[] = { {"round", (PyCFunction)gentype_round, METH_VARARGS | METH_KEYWORDS, NULL}, - /* Hook for the round() builtin */ - {"__round__", - (PyCFunction)gentype_round, - METH_VARARGS | METH_KEYWORDS, NULL}, /* For the format function */ {"__format__", gentype_format, @@ -2130,6 +2178,18 @@ static PyMethodDef @name@type_methods[] = { /**end repeat**/ /**begin repeat + * #name = integer,floating, complexfloating# + */ +static PyMethodDef @name@type_methods[] = { + /* Hook for the round() builtin */ + {"__round__", + (PyCFunction)@name@type_dunder_round, + METH_VARARGS | METH_KEYWORDS, NULL}, + {NULL, NULL, 0, NULL} /* sentinel */ +}; +/**end repeat**/ + +/**begin repeat * #name = half,float,double,longdouble# */ static PyMethodDef @name@type_methods[] = { @@ -3655,8 +3715,8 @@ initialize_numeric_types(void) /**end repeat**/ /**begin repeat - * #name = cfloat, clongdouble# - * #NAME = CFloat, CLongDouble# + * #name = cfloat, clongdouble, floating, integer, complexfloating# + * #NAME = CFloat, CLongDouble, Floating, Integer, ComplexFloating# */ Py@NAME@ArrType_Type.tp_methods = @name@type_methods; diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index 5d35bde6c..82d24e0f7 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -594,3 +594,27 @@ class TestDTypeCoercion(_DeprecationTestCase): for scalar_type in [type, dict, list, tuple]: # Typical python types are coerced to object currently: self.assert_not_deprecated(np.dtype, args=(scalar_type,)) + + +class BuiltInRoundComplexDType(_DeprecationTestCase): + # 2020-03-31 1.19.0 + deprecated_types = [np.csingle, np.cdouble, np.clongdouble] + not_deprecated_types = [ + np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, + ] + + def test_deprecated(self): + for scalar_type in self.deprecated_types: + scalar = scalar_type(0) + self.assert_deprecated(round, args=(scalar,)) + self.assert_deprecated(round, args=(scalar, 0)) + self.assert_deprecated(round, args=(scalar,), kwargs={'ndigits': 0}) + + def test_not_deprecated(self): + for scalar_type in self.not_deprecated_types: + scalar = scalar_type(0) + self.assert_not_deprecated(round, args=(scalar,)) + self.assert_not_deprecated(round, args=(scalar, 0)) + self.assert_not_deprecated(round, args=(scalar,), kwargs={'ndigits': 0}) diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 1bcfe50a4..bcc6a0c4e 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -3,6 +3,7 @@ import warnings import itertools import platform import pytest +import math from decimal import Decimal import numpy as np @@ -11,7 +12,7 @@ from numpy.random import rand, randint, randn from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex, assert_array_equal, assert_almost_equal, assert_array_almost_equal, - assert_warns, HAS_REFCOUNT + assert_warns, assert_array_max_ulp, HAS_REFCOUNT ) from hypothesis import assume, given, strategies as st @@ -139,6 +140,51 @@ class TestNonarrayArgs: arr = [1.56, 72.54, 6.35, 3.25] tgt = [1.6, 72.5, 6.4, 3.2] assert_equal(np.around(arr, decimals=1), tgt) + s = np.float64(1.) + assert_(isinstance(s.round(), np.float64)) + assert_equal(s.round(), 1.) + + @pytest.mark.parametrize('dtype', [ + np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, + ]) + def test_dunder_round(self, dtype): + s = dtype(1) + assert_(isinstance(round(s), int)) + assert_(isinstance(round(s, None), int)) + assert_(isinstance(round(s, ndigits=None), int)) + assert_equal(round(s), 1) + assert_equal(round(s, None), 1) + assert_equal(round(s, ndigits=None), 1) + + @pytest.mark.parametrize('val, ndigits', [ + pytest.param(2**31 - 1, -1, + marks=pytest.mark.xfail(reason="Out of range of int32") + ), + (2**31 - 1, 1-math.ceil(math.log10(2**31 - 1))), + (2**31 - 1, -math.ceil(math.log10(2**31 - 1))) + ]) + def test_dunder_round_edgecases(self, val, ndigits): + assert_equal(round(val, ndigits), round(np.int32(val), ndigits)) + + def test_dunder_round_accuracy(self): + f = np.float64(5.1 * 10**73) + assert_(isinstance(round(f, -73), np.float64)) + assert_array_max_ulp(round(f, -73), 5.0 * 10**73) + assert_(isinstance(round(f, ndigits=-73), np.float64)) + assert_array_max_ulp(round(f, ndigits=-73), 5.0 * 10**73) + + i = np.int64(501) + assert_(isinstance(round(i, -2), np.int64)) + assert_array_max_ulp(round(i, -2), 500) + assert_(isinstance(round(i, ndigits=-2), np.int64)) + assert_array_max_ulp(round(i, ndigits=-2), 500) + + @pytest.mark.xfail(raises=AssertionError, reason="gh-15896") + def test_round_py_consistency(self): + f = 5.1 * 10**73 + assert_equal(round(np.float64(f), -73), round(f, -73)) def test_searchsorted(self): arr = [-8, -5, -1, 3, 6, 10] |