summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPrithvi <prithvisinghtewatia@gmail.com>2021-05-18 18:02:02 +0530
committerPrithvi <prithvisinghtewatia@gmail.com>2022-06-17 21:13:06 +0530
commit890ad666f3968a5cac58b8d3ce734a5bfa1b9825 (patch)
treee86666a8f2618d4c26fa8c21d536e391ca36bfb9 /numpy
parent9c91f21e88697a51b03b8b2d6b63b0d0bb2911b7 (diff)
downloadnumpy-890ad666f3968a5cac58b8d3ce734a5bfa1b9825.tar.gz
Fixed 0 sign and added new tests
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npymath/npy_math_complex.c.src61
-rw-r--r--numpy/core/tests/test_umath.py13
2 files changed, 50 insertions, 24 deletions
diff --git a/numpy/core/src/npymath/npy_math_complex.c.src b/numpy/core/src/npymath/npy_math_complex.c.src
index 0a786234e..7508e34e8 100644
--- a/numpy/core/src/npymath/npy_math_complex.c.src
+++ b/numpy/core/src/npymath/npy_math_complex.c.src
@@ -445,27 +445,46 @@ npy_cpow@c@ (@ctype@ a, @ctype@ b)
@type@ bi = npy_cimag@c@(b);
@ctype@ r;
- /*Checking if in a^b, if b is zero.
- If a is not zero then by definition of logarithm a^0 is one.
- If a is also zero then as per IEEE 0^0 is best defined as 1.*/
- if(br == 0. && bi == 0.){
- return npy_cpack@c@(1.,0.);
- }
- /*Here it is already tested that for a^b, b is not zero.
- So we are checking behaviour of a.*/
- else if(ar == 0. && ai == 0.){
- /*Checking if real part of power is greater than zero then the result is zero.
- If not the result is undefined as it blows up or oscillates.*/
- if(br > 0){
- return npy_cpack@c@(0.,0.);
- }
- else{
- //Generating an invalid
- volatile @type@ tmp=NPY_INFINITY@C@;
- tmp-=NPY_INFINITY@C@;
- ar=tmp;
- return npy_cpack@c@(NPY_NAN@C@, NPY_NAN@C@);
- }
+ /*
+ * Checking if in a^b, if b is zero.
+ * If a is not zero then by definition of logarithm a^0 is 1.
+ * If a is also zero then as per IEEE 0^0 is best defined as 1.
+ */
+ if (br == 0. && bi == 0.) {
+ return npy_cpack@c@(1., 0.);
+ }
+ /*
+ * If mantissa is zero, then result depends upon values of
+ * br and bi. The result is either 0 (in magnitude) or
+ * undefined ( or 1 for br=bi=0 and independent of ar and ai
+ * but that case is handled above).
+ */
+ else if (ar == 0. && ai == 0.) {
+ /* If br > 0 then result is 0 but contains sign opposite
+ * of bi.
+ * Else the result is undefined and we return (nan,nan)
+ */
+ if ( br > 0) {
+ if (bi < 0) {
+ return npy_cpack@c@(0., 0.);
+ }
+ else {
+ return npy_cpack@c@(0., -0.);
+ }
+
+ }
+ else {
+ /* Raising an invalid and returning
+ * (nan, nan)
+ */
+ volatile @type@ tmp = NPY_INFINITY@C@;
+ r = npy_cpack@c@(NPY_NAN@C@, NPY_NAN@C@);
+
+ /* Raise invalid */
+ tmp -= NPY_INFINITY@C@;
+ ar = tmp;
+ return r;
+ }
}
if (bi == 0 && (n=(npy_intp)br) == br) {
if (n == 1) {
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 44a9abda3..e0b8577de 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -840,7 +840,7 @@ class TestPower:
# Testing 0^{Non-zero} issue 18378
def test_zero_power_nonzero(self):
- zero = np.array([0j])
+ zero = np.array([0.0j])
cnan = np.array([complex(np.nan, np.nan)])
def assert_complex_equal(x, y):
@@ -850,11 +850,18 @@ class TestPower:
#Complex powers with positive real part will not generate a warning
assert_complex_equal(np.power(zero, 1+4j), zero)
assert_complex_equal(np.power(zero, 2-3j), zero)
- #Complex powers will negative real part will generate a NAN
- #and hence a RUNTIME warning
+ #Testing zero values when real part is greater than zero
+ assert_complex_equal(np.power(zero, 1+1j), -zero)
+ assert_complex_equal(np.power(zero, 1+0j), -zero)
+ assert_complex_equal(np.power(zero, 1-1j), zero)
+ #Complex powers will negative real part or 0 (provided imaginary
+ # part is not zero) will generate a NAN and hence a RUNTIME warning
with pytest.warns(expected_warning=RuntimeWarning):
assert_complex_equal(np.power(zero, -1+1j), cnan)
assert_complex_equal(np.power(zero, -2-3j), cnan)
+ assert_complex_equal(np.power(zero, -7+0j), cnan)
+ assert_complex_equal(np.power(zero, 0+1j), cnan)
+ assert_complex_equal(np.power(zero, 0-1j), cnan)
def test_fast_power(self):
x = np.array([1, 2, 3], np.int16)