summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2022-12-02 13:41:03 +0200
committerGitHub <noreply@github.com>2022-12-02 13:41:03 +0200
commit9989e1e43fb24d163887f450ef15e6f270b8e15b (patch)
tree10ca8bbb1e7702aa4402ce710c670c9208ab7051 /numpy
parent8ff45c5bb520db04af8720bf1d34a392a8d2561a (diff)
parent1b57a63ca3bff38e9891cae2d23acb8d5f063883 (diff)
downloadnumpy-9989e1e43fb24d163887f450ef15e6f270b8e15b.tar.gz
Merge pull request #18535 from prithvitewatia/Issue18378
BUG: Fix <complex 0>^{non-zero}
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npymath/npy_math_complex.c.src54
-rw-r--r--numpy/core/tests/test_umath.py26
2 files changed, 60 insertions, 20 deletions
diff --git a/numpy/core/src/npymath/npy_math_complex.c.src b/numpy/core/src/npymath/npy_math_complex.c.src
index e0c078444..eff7f8e13 100644
--- a/numpy/core/src/npymath/npy_math_complex.c.src
+++ b/numpy/core/src/npymath/npy_math_complex.c.src
@@ -443,29 +443,43 @@ npy_cpow@c@ (@ctype@ a, @ctype@ b)
@type@ bi = npy_cimag@c@(b);
@ctype@ r;
+ /*
+ * Checking if in a^b, if b is zero.
+ * If a is not zero then by definition of logarithm a^0 is 1.
+ * If a is also zero then 0^0 is best defined as 1.
+ */
if (br == 0. && bi == 0.) {
return npy_cpack@c@(1., 0.);
}
- if (ar == 0. && ai == 0.) {
- if (br > 0 && bi == 0) {
- return npy_cpack@c@(0., 0.);
- }
- else {
- volatile @type@ tmp = NPY_INFINITY@C@;
- /*
- * NB: there are four complex zeros; c0 = (+-0, +-0), so that
- * unlike for reals, c0**p, with `p` negative is in general
- * ill-defined.
- *
- * c0**z with z complex is also ill-defined.
- */
- r = npy_cpack@c@(NPY_NAN@C@, NPY_NAN@C@);
-
- /* Raise invalid */
- tmp -= NPY_INFINITY@C@;
- ar = tmp;
- return r;
- }
+ /* case 0^b
+ * If a is a complex zero (ai=ar=0), then the result depends
+ * upon values of br and bi. The result is either:
+ * 0 (in magnitude), undefined or 1.
+ * The later case is for br=bi=0 and independent of ar and ai
+ * but is handled above).
+ */
+ else if (ar == 0. && ai == 0.) {
+ /*
+ * If the real part of b is positive (br>0) then this is
+ * the zero complex with positive sign on both the
+ * real and imaginary part.
+ */
+ if (br > 0) {
+ return npy_cpack@c@(0., 0.);
+ }
+ /* else we are in the case where the
+ * real part of b is negative (br<0).
+ * Here we should return a complex nan
+ * and raise FloatingPointError: invalid value...
+ */
+
+ /* Raise invalid value by calling inf - inf*/
+ volatile @type@ tmp = NPY_INFINITY@C@;
+ tmp -= NPY_INFINITY@C@;
+ ar = tmp;
+
+ r = npy_cpack@c@(NPY_NAN@C@, NPY_NAN@C@);
+ return r;
}
if (bi == 0 && (n=(npy_intp)br) == br) {
if (n == 1) {
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index e0f91e326..1160eca54 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1069,6 +1069,32 @@ class TestPower:
assert_complex_equal(np.power(zero, -p), cnan)
assert_complex_equal(np.power(zero, -1+0.2j), cnan)
+ # Testing 0^{Non-zero} issue 18378
+ def test_zero_power_nonzero(self):
+ zero = np.array([0.0+0.0j])
+ cnan = np.array([complex(np.nan, np.nan)])
+
+ def assert_complex_equal(x, y):
+ assert_array_equal(x.real, y.real)
+ assert_array_equal(x.imag, y.imag)
+
+ #Complex powers with positive real part will not generate a warning
+ assert_complex_equal(np.power(zero, 1+4j), zero)
+ assert_complex_equal(np.power(zero, 2-3j), zero)
+ #Testing zero values when real part is greater than zero
+ assert_complex_equal(np.power(zero, 1+1j), zero)
+ assert_complex_equal(np.power(zero, 1+0j), zero)
+ assert_complex_equal(np.power(zero, 1-1j), zero)
+ #Complex powers will negative real part or 0 (provided imaginary
+ # part is not zero) will generate a NAN and hence a RUNTIME warning
+ with pytest.warns(expected_warning=RuntimeWarning) as r:
+ assert_complex_equal(np.power(zero, -1+1j), cnan)
+ assert_complex_equal(np.power(zero, -2-3j), cnan)
+ assert_complex_equal(np.power(zero, -7+0j), cnan)
+ assert_complex_equal(np.power(zero, 0+1j), cnan)
+ assert_complex_equal(np.power(zero, 0-1j), cnan)
+ assert len(r) == 5
+
def test_fast_power(self):
x = np.array([1, 2, 3], np.int16)
res = x**2.0