summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-12-07 22:22:45 +0000
committerPauli Virtanen <pav@iki.fi>2009-12-07 22:22:45 +0000
commit801f2b2eed815d2489acd44ed9c8a7e581a4d433 (patch)
tree9f9fa39a177c07f2851f9e06ac6acc2931142508
parent37f318c8be5c0f0eb7e7d217690446fbbc0d50d2 (diff)
downloadnumpy-801f2b2eed815d2489acd44ed9c8a7e581a4d433.tar.gz
BUG: Fix bugs in complex pow (fixes #1313)
Thanks to Francesc Alted.
-rw-r--r--numpy/core/src/npymath/npy_math_complex.c.src15
-rw-r--r--numpy/core/src/scalarmathmodule.c.src2
-rw-r--r--numpy/core/tests/test_umath_complex.py26
3 files changed, 34 insertions, 9 deletions
diff --git a/numpy/core/src/npymath/npy_math_complex.c.src b/numpy/core/src/npymath/npy_math_complex.c.src
index 583328782..b954e93a9 100644
--- a/numpy/core/src/npymath/npy_math_complex.c.src
+++ b/numpy/core/src/npymath/npy_math_complex.c.src
@@ -195,17 +195,16 @@
#ifndef HAVE_CPOW@C@
@ctype@ npy_cpow@c@ (@ctype@ x, @ctype@ y)
{
- @ctype@ b, p;
- @type@ bx, by, px, py;
+ @ctype@ b;
+ @type@ br, bi, yr, yi;
+ yr = npy_creal@c@(y);
+ yi = npy_cimag@c@(y);
b = npy_clog@c@(x);
- p = npy_clog@c@(y);
- bx = npy_creal@c@(b);
- by = npy_cimag@c@(b);
- px = npy_creal@c@(p);
- py = npy_cimag@c@(p);
+ br = npy_creal@c@(b);
+ bi = npy_cimag@c@(b);
- return npy_cexp@c@(npy_cpack@c@(bx * px - by * py, bx * py + by * px));
+ return npy_cexp@c@(npy_cpack@c@(br * yr - bi * yi, br * yi + bi * yr));
}
#endif
diff --git a/numpy/core/src/scalarmathmodule.c.src b/numpy/core/src/scalarmathmodule.c.src
index d2be8d69c..5e8cd812f 100644
--- a/numpy/core/src/scalarmathmodule.c.src
+++ b/numpy/core/src/scalarmathmodule.c.src
@@ -806,7 +806,7 @@ static PyObject *
* as a function call.
*/
#if @cmplx@
- if (arg2.real == 0 && arg1.real == 0) {
+ if (arg2.real == 0 && arg2.imag == 0) {
out1.real = out.real = 1;
out1.imag = out.imag = 0;
}
diff --git a/numpy/core/tests/test_umath_complex.py b/numpy/core/tests/test_umath_complex.py
index b041a558b..c180b5e4a 100644
--- a/numpy/core/tests/test_umath_complex.py
+++ b/numpy/core/tests/test_umath_complex.py
@@ -309,6 +309,32 @@ class TestCpow(TestCase):
for i in range(len(x)):
assert_almost_equal(y[i], y_r[i])
+ def test_scalar(self):
+ x = np.array([1, 1j, 2, 2.5+.37j, np.inf, np.nan])
+ y = np.array([1, 1j, -0.5+1.5j, -0.5+1.5j, 2, 3])
+ lx = range(len(x))
+ # Compute the values for complex type in python
+ p_r = [complex(x[i]) ** complex(y[i]) for i in lx]
+ # Substitute a result allowed by C99 standard
+ p_r[4] = complex(np.inf, np.nan)
+ # Do the same with numpy complex scalars
+ n_r = [x[i] ** y[i] for i in lx]
+ for i in lx:
+ assert_almost_equal(n_r[i], p_r[i], err_msg='Loop %d\n' % i)
+
+ def test_array(self):
+ x = np.array([1, 1j, 2, 2.5+.37j, np.inf, np.nan])
+ y = np.array([1, 1j, -0.5+1.5j, -0.5+1.5j, 2, 3])
+ lx = range(len(x))
+ # Compute the values for complex type in python
+ p_r = [complex(x[i]) ** complex(y[i]) for i in lx]
+ # Substitute a result allowed by C99 standard
+ p_r[4] = complex(np.inf, np.nan)
+ # Do the same with numpy arrays
+ n_r = x ** y
+ for i in lx:
+ assert_almost_equal(n_r[i], p_r[i], err_msg='Loop %d\n' % i)
+
class TestCabs(object):
def test_simple(self):
x = np.array([1+1j, 0+2j, 1+2j, np.inf, np.nan])