summaryrefslogtreecommitdiff
path: root/numpy/lib/scimath.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/scimath.py')
-rw-r--r--numpy/lib/scimath.py83
1 files changed, 60 insertions, 23 deletions
diff --git a/numpy/lib/scimath.py b/numpy/lib/scimath.py
index e07caf805..5ac790ce9 100644
--- a/numpy/lib/scimath.py
+++ b/numpy/lib/scimath.py
@@ -20,6 +20,7 @@ from __future__ import division, absolute_import, print_function
import numpy.core.numeric as nx
import numpy.core.numerictypes as nt
from numpy.core.numeric import asarray, any
+from numpy.core.overrides import array_function_dispatch
from numpy.lib.type_check import isreal
@@ -58,7 +59,7 @@ def _tocomplex(arr):
>>> a = np.array([1,2,3],np.short)
>>> ac = np.lib.scimath._tocomplex(a); ac
- array([ 1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
+ array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
>>> ac.dtype
dtype('complex64')
@@ -69,7 +70,7 @@ def _tocomplex(arr):
>>> b = np.array([1,2,3],np.double)
>>> bc = np.lib.scimath._tocomplex(b); bc
- array([ 1.+0.j, 2.+0.j, 3.+0.j])
+ array([1.+0.j, 2.+0.j, 3.+0.j])
>>> bc.dtype
dtype('complex128')
@@ -80,13 +81,13 @@ def _tocomplex(arr):
>>> c = np.array([1,2,3],np.csingle)
>>> cc = np.lib.scimath._tocomplex(c); cc
- array([ 1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
+ array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
>>> c *= 2; c
- array([ 2.+0.j, 4.+0.j, 6.+0.j], dtype=complex64)
+ array([2.+0.j, 4.+0.j, 6.+0.j], dtype=complex64)
>>> cc
- array([ 1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
+ array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
"""
if issubclass(arr.dtype.type, (nt.single, nt.byte, nt.short, nt.ubyte,
nt.ushort, nt.csingle)):
@@ -94,6 +95,7 @@ def _tocomplex(arr):
else:
return arr.astype(nt.cdouble)
+
def _fix_real_lt_zero(x):
"""Convert `x` to complex if it has real, negative components.
@@ -121,6 +123,7 @@ def _fix_real_lt_zero(x):
x = _tocomplex(x)
return x
+
def _fix_int_lt_zero(x):
"""Convert `x` to double if it has real, negative components.
@@ -147,6 +150,7 @@ def _fix_int_lt_zero(x):
x = x * 1.0
return x
+
def _fix_real_abs_gt_1(x):
"""Convert `x` to complex if it has real components x_i with abs(x_i)>1.
@@ -166,13 +170,19 @@ def _fix_real_abs_gt_1(x):
array([0, 1])
>>> np.lib.scimath._fix_real_abs_gt_1([0,2])
- array([ 0.+0.j, 2.+0.j])
+ array([0.+0.j, 2.+0.j])
"""
x = asarray(x)
if any(isreal(x) & (abs(x) > 1)):
x = _tocomplex(x)
return x
+
+def _unary_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_unary_dispatcher)
def sqrt(x):
"""
Compute the square root of x.
@@ -202,19 +212,21 @@ def sqrt(x):
>>> np.lib.scimath.sqrt(1)
1.0
>>> np.lib.scimath.sqrt([1, 4])
- array([ 1., 2.])
+ array([1., 2.])
But it automatically handles negative inputs:
>>> np.lib.scimath.sqrt(-1)
- (0.0+1.0j)
+ 1j
>>> np.lib.scimath.sqrt([-1,4])
- array([ 0.+1.j, 2.+0.j])
+ array([0.+1.j, 2.+0.j])
"""
x = _fix_real_lt_zero(x)
return nx.sqrt(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log(x):
"""
Compute the natural logarithm of `x`.
@@ -261,6 +273,8 @@ def log(x):
x = _fix_real_lt_zero(x)
return nx.log(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log10(x):
"""
Compute the logarithm base 10 of `x`.
@@ -303,12 +317,18 @@ def log10(x):
1.0
>>> np.emath.log10([-10**1, -10**2, 10**2])
- array([ 1.+1.3644j, 2.+1.3644j, 2.+0.j ])
+ array([1.+1.3644j, 2.+1.3644j, 2.+0.j ])
"""
x = _fix_real_lt_zero(x)
return nx.log10(x)
+
+def _logn_dispatcher(n, x):
+ return (n, x,)
+
+
+@array_function_dispatch(_logn_dispatcher)
def logn(n, x):
"""
Take log base n of x.
@@ -318,8 +338,8 @@ def logn(n, x):
Parameters
----------
- n : int
- The base in which the log is taken.
+ n : array_like
+ The integer base(s) in which the log is taken.
x : array_like
The value(s) whose log base `n` is (are) required.
@@ -334,15 +354,17 @@ def logn(n, x):
>>> np.set_printoptions(precision=4)
>>> np.lib.scimath.logn(2, [4, 8])
- array([ 2., 3.])
+ array([2., 3.])
>>> np.lib.scimath.logn(2, [-4, -8, 8])
- array([ 2.+4.5324j, 3.+4.5324j, 3.+0.j ])
+ array([2.+4.5324j, 3.+4.5324j, 3.+0.j ])
"""
x = _fix_real_lt_zero(x)
n = _fix_real_lt_zero(n)
return nx.log(x)/nx.log(n)
+
+@array_function_dispatch(_unary_dispatcher)
def log2(x):
"""
Compute the logarithm base 2 of `x`.
@@ -383,12 +405,18 @@ def log2(x):
>>> np.emath.log2(8)
3.0
>>> np.emath.log2([-4, -8, 8])
- array([ 2.+4.5324j, 3.+4.5324j, 3.+0.j ])
+ array([2.+4.5324j, 3.+4.5324j, 3.+0.j ])
"""
x = _fix_real_lt_zero(x)
return nx.log2(x)
+
+def _power_dispatcher(x, p):
+ return (x, p)
+
+
+@array_function_dispatch(_power_dispatcher)
def power(x, p):
"""
Return x to the power p, (x**p).
@@ -423,15 +451,17 @@ def power(x, p):
>>> np.lib.scimath.power([2, 4], 2)
array([ 4, 16])
>>> np.lib.scimath.power([2, 4], -2)
- array([ 0.25 , 0.0625])
+ array([0.25 , 0.0625])
>>> np.lib.scimath.power([-2, 4], 2)
- array([ 4.+0.j, 16.+0.j])
+ array([ 4.-0.j, 16.+0.j])
"""
x = _fix_real_lt_zero(x)
p = _fix_int_lt_zero(p)
return nx.power(x, p)
+
+@array_function_dispatch(_unary_dispatcher)
def arccos(x):
"""
Compute the inverse cosine of x.
@@ -469,12 +499,14 @@ def arccos(x):
0.0
>>> np.emath.arccos([1,2])
- array([ 0.-0.j , 0.+1.317j])
+ array([0.-0.j , 0.-1.317j])
"""
x = _fix_real_abs_gt_1(x)
return nx.arccos(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arcsin(x):
"""
Compute the inverse sine of x.
@@ -513,12 +545,14 @@ def arcsin(x):
0.0
>>> np.emath.arcsin([0,1])
- array([ 0. , 1.5708])
+ array([0. , 1.5708])
"""
x = _fix_real_abs_gt_1(x)
return nx.arcsin(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arctanh(x):
"""
Compute the inverse hyperbolic tangent of `x`.
@@ -555,11 +589,14 @@ def arctanh(x):
--------
>>> np.set_printoptions(precision=4)
- >>> np.emath.arctanh(np.matrix(np.eye(2)))
- array([[ Inf, 0.],
- [ 0., Inf]])
+ >>> from numpy.testing import suppress_warnings
+ >>> with suppress_warnings() as sup:
+ ... sup.filter(RuntimeWarning)
+ ... np.emath.arctanh(np.eye(2))
+ array([[inf, 0.],
+ [ 0., inf]])
>>> np.emath.arctanh([1j])
- array([ 0.+0.7854j])
+ array([0.+0.7854j])
"""
x = _fix_real_abs_gt_1(x)