summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorHameer Abbasi <einstein.edison@gmail.com>2020-04-10 20:19:36 +0200
committerGitHub <noreply@github.com>2020-04-10 13:19:36 -0500
commit1f3650643f341af8ca5c12529a479d73d3103895 (patch)
tree11ba7e500599d4aab078d8cb21b5c5f9e02ebe4e /numpy
parent42228fcb513c08729f822fa6be70e16a7a68067f (diff)
downloadnumpy-1f3650643f341af8ca5c12529a479d73d3103895.tar.gz
BUG,DEP: Make `scalar.__round__()` behave like pythons round (#15840)
See issue gh-15297 and related mailing list discussion. This PR bring scalar.__round__() in line with python, so that `round(scalar)` always returns a python integer, while `round(scalar, ndigits=0)` returns the same type. Since complex numbers are not supported in Python, and cannot be reasonably cast to integers they are deprecated. Closes gh-15297
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src72
-rw-r--r--numpy/core/tests/test_deprecations.py24
-rw-r--r--numpy/core/tests/test_numeric.py48
3 files changed, 137 insertions, 7 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index eafa13ff2..2f1767391 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1549,6 +1549,58 @@ gentype_@name@(PyObject *self, PyObject *args, PyObject *kwds)
}
/**end repeat**/
+
+/**begin repeat
+ * #name = integer, floating, complexfloating#
+ * #complex = 0, 0, 1#
+ */
+static PyObject *
+@name@type_dunder_round(PyObject *self, PyObject *args, PyObject *kwds)
+{
+ static char *kwlist[] = {"ndigits", NULL};
+ PyObject *ndigits = Py_None;
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O:__round__", kwlist, &ndigits)) {
+ return NULL;
+ }
+
+#if @complex@
+ if (DEPRECATE("The Python built-in `round` is deprecated for complex "
+ "scalars, and will raise a `TypeError` in a future release. "
+ "Use `np.round` or `scalar.round` instead.") < 0) {
+ return NULL;
+ }
+#endif
+
+ PyObject *tup;
+ if (ndigits == Py_None) {
+ tup = PyTuple_Pack(0);
+ }
+ else {
+ tup = PyTuple_Pack(1, ndigits);
+ }
+
+ if (tup == NULL) {
+ return NULL;
+ }
+
+ PyObject *obj = gentype_round(self, tup, NULL);
+ Py_DECREF(tup);
+ if (obj == NULL) {
+ return NULL;
+ }
+
+#if !@complex@
+ if (ndigits == Py_None) {
+ PyObject *ret = PyNumber_Long(obj);
+ Py_DECREF(obj);
+ return ret;
+ }
+#endif
+
+ return obj;
+}
+/**end repeat**/
+
static PyObject *
voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
{
@@ -2056,10 +2108,6 @@ static PyMethodDef gentype_methods[] = {
{"round",
(PyCFunction)gentype_round,
METH_VARARGS | METH_KEYWORDS, NULL},
- /* Hook for the round() builtin */
- {"__round__",
- (PyCFunction)gentype_round,
- METH_VARARGS | METH_KEYWORDS, NULL},
/* For the format function */
{"__format__",
gentype_format,
@@ -2130,6 +2178,18 @@ static PyMethodDef @name@type_methods[] = {
/**end repeat**/
/**begin repeat
+ * #name = integer,floating, complexfloating#
+ */
+static PyMethodDef @name@type_methods[] = {
+ /* Hook for the round() builtin */
+ {"__round__",
+ (PyCFunction)@name@type_dunder_round,
+ METH_VARARGS | METH_KEYWORDS, NULL},
+ {NULL, NULL, 0, NULL} /* sentinel */
+};
+/**end repeat**/
+
+/**begin repeat
* #name = half,float,double,longdouble#
*/
static PyMethodDef @name@type_methods[] = {
@@ -3655,8 +3715,8 @@ initialize_numeric_types(void)
/**end repeat**/
/**begin repeat
- * #name = cfloat, clongdouble#
- * #NAME = CFloat, CLongDouble#
+ * #name = cfloat, clongdouble, floating, integer, complexfloating#
+ * #NAME = CFloat, CLongDouble, Floating, Integer, ComplexFloating#
*/
Py@NAME@ArrType_Type.tp_methods = @name@type_methods;
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 5d35bde6c..82d24e0f7 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -594,3 +594,27 @@ class TestDTypeCoercion(_DeprecationTestCase):
for scalar_type in [type, dict, list, tuple]:
# Typical python types are coerced to object currently:
self.assert_not_deprecated(np.dtype, args=(scalar_type,))
+
+
+class BuiltInRoundComplexDType(_DeprecationTestCase):
+ # 2020-03-31 1.19.0
+ deprecated_types = [np.csingle, np.cdouble, np.clongdouble]
+ not_deprecated_types = [
+ np.int8, np.int16, np.int32, np.int64,
+ np.uint8, np.uint16, np.uint32, np.uint64,
+ np.float16, np.float32, np.float64,
+ ]
+
+ def test_deprecated(self):
+ for scalar_type in self.deprecated_types:
+ scalar = scalar_type(0)
+ self.assert_deprecated(round, args=(scalar,))
+ self.assert_deprecated(round, args=(scalar, 0))
+ self.assert_deprecated(round, args=(scalar,), kwargs={'ndigits': 0})
+
+ def test_not_deprecated(self):
+ for scalar_type in self.not_deprecated_types:
+ scalar = scalar_type(0)
+ self.assert_not_deprecated(round, args=(scalar,))
+ self.assert_not_deprecated(round, args=(scalar, 0))
+ self.assert_not_deprecated(round, args=(scalar,), kwargs={'ndigits': 0})
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 1bcfe50a4..bcc6a0c4e 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -3,6 +3,7 @@ import warnings
import itertools
import platform
import pytest
+import math
from decimal import Decimal
import numpy as np
@@ -11,7 +12,7 @@ from numpy.random import rand, randint, randn
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex,
assert_array_equal, assert_almost_equal, assert_array_almost_equal,
- assert_warns, HAS_REFCOUNT
+ assert_warns, assert_array_max_ulp, HAS_REFCOUNT
)
from hypothesis import assume, given, strategies as st
@@ -139,6 +140,51 @@ class TestNonarrayArgs:
arr = [1.56, 72.54, 6.35, 3.25]
tgt = [1.6, 72.5, 6.4, 3.2]
assert_equal(np.around(arr, decimals=1), tgt)
+ s = np.float64(1.)
+ assert_(isinstance(s.round(), np.float64))
+ assert_equal(s.round(), 1.)
+
+ @pytest.mark.parametrize('dtype', [
+ np.int8, np.int16, np.int32, np.int64,
+ np.uint8, np.uint16, np.uint32, np.uint64,
+ np.float16, np.float32, np.float64,
+ ])
+ def test_dunder_round(self, dtype):
+ s = dtype(1)
+ assert_(isinstance(round(s), int))
+ assert_(isinstance(round(s, None), int))
+ assert_(isinstance(round(s, ndigits=None), int))
+ assert_equal(round(s), 1)
+ assert_equal(round(s, None), 1)
+ assert_equal(round(s, ndigits=None), 1)
+
+ @pytest.mark.parametrize('val, ndigits', [
+ pytest.param(2**31 - 1, -1,
+ marks=pytest.mark.xfail(reason="Out of range of int32")
+ ),
+ (2**31 - 1, 1-math.ceil(math.log10(2**31 - 1))),
+ (2**31 - 1, -math.ceil(math.log10(2**31 - 1)))
+ ])
+ def test_dunder_round_edgecases(self, val, ndigits):
+ assert_equal(round(val, ndigits), round(np.int32(val), ndigits))
+
+ def test_dunder_round_accuracy(self):
+ f = np.float64(5.1 * 10**73)
+ assert_(isinstance(round(f, -73), np.float64))
+ assert_array_max_ulp(round(f, -73), 5.0 * 10**73)
+ assert_(isinstance(round(f, ndigits=-73), np.float64))
+ assert_array_max_ulp(round(f, ndigits=-73), 5.0 * 10**73)
+
+ i = np.int64(501)
+ assert_(isinstance(round(i, -2), np.int64))
+ assert_array_max_ulp(round(i, -2), 500)
+ assert_(isinstance(round(i, ndigits=-2), np.int64))
+ assert_array_max_ulp(round(i, ndigits=-2), 500)
+
+ @pytest.mark.xfail(raises=AssertionError, reason="gh-15896")
+ def test_round_py_consistency(self):
+ f = 5.1 * 10**73
+ assert_equal(round(np.float64(f), -73), round(f, -73))
def test_searchsorted(self):
arr = [-8, -5, -1, 3, 6, 10]