summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2010-07-31 13:07:42 +0000
committerPauli Virtanen <pav@iki.fi>2010-07-31 13:07:42 +0000
commit869ccb83cbff2cb10842f7b3080f460ca8ba8b55 (patch)
treebf237a530ecb061d93c437f8079100a1e8674a28
parent400a612118fb24f28ee5c83d027432bc66492eff (diff)
downloadnumpy-869ccb83cbff2cb10842f7b3080f460ca8ba8b55.tar.gz
BUG: core/umath: fix powers of complex 0 (#1271)
-rw-r--r--numpy/core/src/umath/funcs.inc.src17
-rw-r--r--numpy/core/tests/test_umath.py25
2 files changed, 41 insertions, 1 deletions
diff --git a/numpy/core/src/umath/funcs.inc.src b/numpy/core/src/umath/funcs.inc.src
index d8127322c..5dc58e990 100644
--- a/numpy/core/src/umath/funcs.inc.src
+++ b/numpy/core/src/umath/funcs.inc.src
@@ -238,7 +238,22 @@ nc_pow@c@(c@typ@ *a, c@typ@ *b, c@typ@ *r)
return;
}
if (ar == 0. && ai == 0.) {
- *r = npy_cpack@c@(0., 0.);
+ if (br > 0 && bi == 0) {
+ *r = npy_cpack@c@(0., 0.);
+ }
+ else {
+ /* NB: there are four complex zeros; c0 = (+-0, +-0), so that unlike
+ * for reals, c0**p, with `p` negative is in general
+ * ill-defined.
+ *
+ * c0**z with z complex is also ill-defined.
+ */
+ *r = npy_cpack@c@(NPY_NAN, NPY_NAN);
+
+ /* Raise invalid */
+ ar = NPY_INFINITY;
+ ar = ar - ar;
+ }
return;
}
if (bi == 0 && (n=(intp)br) == br) {
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 2741287ec..2dace8c16 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -86,6 +86,31 @@ class TestPower(TestCase):
finally:
np.seterr(**err)
+ def test_power_zero(self):
+ # ticket #1271
+ zero = np.array([0j])
+ one = np.array([1+0j])
+ cinf = np.array([complex(np.inf, 0)])
+ cnan = np.array([complex(np.nan, np.nan)])
+
+ def assert_complex_equal(x, y):
+ x, y = np.asarray(x), np.asarray(y)
+ assert_array_equal(x.real, y.real)
+ assert_array_equal(x.imag, y.imag)
+
+ # positive powers
+ for p in [0.33, 0.5, 1, 1.5, 2, 3, 4, 5, 6.6]:
+ assert_complex_equal(np.power(zero, p), zero)
+
+ # zero power
+ assert_complex_equal(np.power(zero, 0), one)
+ assert_complex_equal(np.power(zero, 0+1j), cnan)
+
+ # negative power
+ for p in [0.33, 0.5, 1, 1.5, 2, 3, 4, 5, 6.6]:
+ assert_complex_equal(np.power(zero, -p), cnan)
+ assert_complex_equal(np.power(zero, -1+0.2j), cnan)
+
class TestLog2(TestCase):
def test_log2_values(self) :