summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-12-10 19:12:41 +0000
committerPauli Virtanen <pav@iki.fi>2009-12-10 19:12:41 +0000
commit794a6c4511ced84c74bc8f2dd8cd8f277925a6ac (patch)
tree406d57b79c829f4ec2f3f7dfc5d28c5b7fa6daca
parent8d24b14a7eed3fba60bab462e873214cc9fa2a1f (diff)
downloadnumpy-794a6c4511ced84c74bc8f2dd8cd8f277925a6ac.tar.gz
ENH: emit ComplexWarning when casting complex to real (addresses #1319)
Casting complex numbers to real discards the imaginary part, which may be unexpected. For safety, emit a warning when this occurs.
-rw-r--r--doc/release/1.5.0-notes.rst29
-rw-r--r--numpy/core/numeric.py13
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c16
-rw-r--r--numpy/core/tests/test_multiarray.py10
4 files changed, 64 insertions, 4 deletions
diff --git a/doc/release/1.5.0-notes.rst b/doc/release/1.5.0-notes.rst
index 7a22620c8..42254b789 100644
--- a/doc/release/1.5.0-notes.rst
+++ b/doc/release/1.5.0-notes.rst
@@ -2,7 +2,36 @@
NumPy 1.5.0 Release Notes
=========================
+
+Plans
+=====
+
This release has the following aims:
* Python 3 compatibility
* :pep:`3118` compatibility
+
+
+Highlights
+==========
+
+
+New features
+============
+
+Warning on casting complex to real
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Numpy now emits a `numpy.ComplexWarning` when a complex number is cast
+into a real number. For example:
+
+ >>> x = np.array([1,2,3])
+ >>> x[:2] = np.array([1+2j, 1-2j])
+ ComplexWarning: Casting complex values to real discards the imaginary part
+
+The cast indeed discards the imaginary part, and this may not be the
+intended behavior in all cases, hence the warning. This warning can be
+turned off in the standard way:
+
+ >>> import warnings
+ >>> warnings.simplefilter("ignore", np.ComplexWarning)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index e35fca07b..8c4fa2980 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -19,7 +19,8 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc',
'seterrcall', 'geterrcall', 'errstate', 'flatnonzero',
'Inf', 'inf', 'infty', 'Infinity',
'nan', 'NaN', 'False_', 'True_', 'bitwise_not',
- 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS']
+ 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS',
+ 'ComplexWarning']
import sys
import warnings
@@ -32,6 +33,16 @@ from numerictypes import *
if sys.version_info[0] < 3:
__all__.extend(['getbuffer', 'newbuffer'])
+class ComplexWarning(RuntimeWarning):
+ """
+ Warning that is raised when casting complex numbers to real.
+
+ Casting a complex number to real discards its imaginary part, and
+ this behavior may not be what is intended in all cases.
+
+ """
+ pass
+
bitwise_not = invert
CLIP = multiarray.CLIP
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 049a98041..229b8136d 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -106,11 +106,21 @@ PyArray_GetCastFunc(PyArray_Descr *descr, int type_num)
castfunc = PyCObject_AsVoidPtr(cobj);
}
}
- if (castfunc) {
- return castfunc;
+ }
+ if (PyTypeNum_ISCOMPLEX(descr->type_num) &&
+ !PyTypeNum_ISCOMPLEX(type_num)) {
+ PyObject *cls = NULL, *obj = NULL;
+ obj = PyImport_ImportModule("numpy.core");
+ if (obj) {
+ cls = PyObject_GetAttrString(obj, "ComplexWarning");
+ Py_DECREF(obj);
}
+ PyErr_WarnEx(cls,
+ "Casting complex values to real discards the imaginary "
+ "part", 0);
+ Py_XDECREF(cls);
}
- else {
+ if (castfunc) {
return castfunc;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 45ca22fea..4fc78b128 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1425,6 +1425,16 @@ class TestStackedNeighborhoodIter(TestCase):
[-1, 2], NEIGH_MODE['circular'])
assert_array_equal(l, r)
+class TestWarnings(object):
+ def test_complex_warning(self):
+ import warnings
+
+ x = np.array([1,2])
+ y = np.array([1-2j,1+2j])
+
+ warnings.simplefilter("error", np.ComplexWarning)
+ assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y)
+ warnings.simplefilter("default", np.ComplexWarning)
if sys.version_info >= (2, 6):
class TestNewBufferProtocol(object):