diff options
author | Pauli Virtanen <pav@iki.fi> | 2009-12-10 19:12:41 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2009-12-10 19:12:41 +0000 |
commit | 794a6c4511ced84c74bc8f2dd8cd8f277925a6ac (patch) | |
tree | 406d57b79c829f4ec2f3f7dfc5d28c5b7fa6daca | |
parent | 8d24b14a7eed3fba60bab462e873214cc9fa2a1f (diff) | |
download | numpy-794a6c4511ced84c74bc8f2dd8cd8f277925a6ac.tar.gz |
ENH: emit ComplexWarning when casting complex to real (addresses #1319)
Casting complex numbers to real discards the imaginary part, which may
be unexpected. For safety, emit a warning when this occurs.
-rw-r--r-- | doc/release/1.5.0-notes.rst | 29 | ||||
-rw-r--r-- | numpy/core/numeric.py | 13 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 16 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 10 |
4 files changed, 64 insertions, 4 deletions
diff --git a/doc/release/1.5.0-notes.rst b/doc/release/1.5.0-notes.rst index 7a22620c8..42254b789 100644 --- a/doc/release/1.5.0-notes.rst +++ b/doc/release/1.5.0-notes.rst @@ -2,7 +2,36 @@ NumPy 1.5.0 Release Notes ========================= + +Plans +===== + This release has the following aims: * Python 3 compatibility * :pep:`3118` compatibility + + +Highlights +========== + + +New features +============ + +Warning on casting complex to real +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Numpy now emits a `numpy.ComplexWarning` when a complex number is cast +into a real number. For example: + + >>> x = np.array([1,2,3]) + >>> x[:2] = np.array([1+2j, 1-2j]) + ComplexWarning: Casting complex values to real discards the imaginary part + +The cast indeed discards the imaginary part, and this may not be the +intended behavior in all cases, hence the warning. This warning can be +turned off in the standard way: + + >>> import warnings + >>> warnings.simplefilter("ignore", np.ComplexWarning) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index e35fca07b..8c4fa2980 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -19,7 +19,8 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_', 'bitwise_not', - 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS'] + 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', + 'ComplexWarning'] import sys import warnings @@ -32,6 +33,16 @@ from numerictypes import * if sys.version_info[0] < 3: __all__.extend(['getbuffer', 'newbuffer']) +class ComplexWarning(RuntimeWarning): + """ + Warning that is raised when casting complex numbers to real. + + Casting a complex number to real discards its imaginary part, and + this behavior may not be what is intended in all cases. + + """ + pass + bitwise_not = invert CLIP = multiarray.CLIP diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 049a98041..229b8136d 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -106,11 +106,21 @@ PyArray_GetCastFunc(PyArray_Descr *descr, int type_num) castfunc = PyCObject_AsVoidPtr(cobj); } } - if (castfunc) { - return castfunc; + } + if (PyTypeNum_ISCOMPLEX(descr->type_num) && + !PyTypeNum_ISCOMPLEX(type_num)) { + PyObject *cls = NULL, *obj = NULL; + obj = PyImport_ImportModule("numpy.core"); + if (obj) { + cls = PyObject_GetAttrString(obj, "ComplexWarning"); + Py_DECREF(obj); } + PyErr_WarnEx(cls, + "Casting complex values to real discards the imaginary " + "part", 0); + Py_XDECREF(cls); } - else { + if (castfunc) { return castfunc; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 45ca22fea..4fc78b128 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1425,6 +1425,16 @@ class TestStackedNeighborhoodIter(TestCase): [-1, 2], NEIGH_MODE['circular']) assert_array_equal(l, r) +class TestWarnings(object): + def test_complex_warning(self): + import warnings + + x = np.array([1,2]) + y = np.array([1-2j,1+2j]) + + warnings.simplefilter("error", np.ComplexWarning) + assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y) + warnings.simplefilter("default", np.ComplexWarning) if sys.version_info >= (2, 6): class TestNewBufferProtocol(object): |