summaryrefslogtreecommitdiff
path: root/numpy/lib/scimath.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/scimath.py')
-rw-r--r--numpy/lib/scimath.py38
1 files changed, 36 insertions, 2 deletions
diff --git a/numpy/lib/scimath.py b/numpy/lib/scimath.py
index f1838fee6..9ca006841 100644
--- a/numpy/lib/scimath.py
+++ b/numpy/lib/scimath.py
@@ -20,6 +20,7 @@ from __future__ import division, absolute_import, print_function
import numpy.core.numeric as nx
import numpy.core.numerictypes as nt
from numpy.core.numeric import asarray, any
+from numpy.core.overrides import array_function_dispatch
from numpy.lib.type_check import isreal
@@ -94,6 +95,7 @@ def _tocomplex(arr):
else:
return arr.astype(nt.cdouble)
+
def _fix_real_lt_zero(x):
"""Convert `x` to complex if it has real, negative components.
@@ -121,6 +123,7 @@ def _fix_real_lt_zero(x):
x = _tocomplex(x)
return x
+
def _fix_int_lt_zero(x):
"""Convert `x` to double if it has real, negative components.
@@ -147,6 +150,7 @@ def _fix_int_lt_zero(x):
x = x * 1.0
return x
+
def _fix_real_abs_gt_1(x):
"""Convert `x` to complex if it has real components x_i with abs(x_i)>1.
@@ -173,6 +177,12 @@ def _fix_real_abs_gt_1(x):
x = _tocomplex(x)
return x
+
+def _unary_dispatcher(x):
+ return (x,)
+
+
+@array_function_dispatch(_unary_dispatcher)
def sqrt(x):
"""
Compute the square root of x.
@@ -215,6 +225,8 @@ def sqrt(x):
x = _fix_real_lt_zero(x)
return nx.sqrt(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log(x):
"""
Compute the natural logarithm of `x`.
@@ -261,6 +273,8 @@ def log(x):
x = _fix_real_lt_zero(x)
return nx.log(x)
+
+@array_function_dispatch(_unary_dispatcher)
def log10(x):
"""
Compute the logarithm base 10 of `x`.
@@ -309,6 +323,12 @@ def log10(x):
x = _fix_real_lt_zero(x)
return nx.log10(x)
+
+def _logn_dispatcher(n, x):
+ return (n, x,)
+
+
+@array_function_dispatch(_logn_dispatcher)
def logn(n, x):
"""
Take log base n of x.
@@ -318,8 +338,8 @@ def logn(n, x):
Parameters
----------
- n : int
- The base in which the log is taken.
+ n : array_like
+ The integer base(s) in which the log is taken.
x : array_like
The value(s) whose log base `n` is (are) required.
@@ -343,6 +363,8 @@ def logn(n, x):
n = _fix_real_lt_zero(n)
return nx.log(x)/nx.log(n)
+
+@array_function_dispatch(_unary_dispatcher)
def log2(x):
"""
Compute the logarithm base 2 of `x`.
@@ -389,6 +411,12 @@ def log2(x):
x = _fix_real_lt_zero(x)
return nx.log2(x)
+
+def _power_dispatcher(x, p):
+ return (x, p)
+
+
+@array_function_dispatch(_power_dispatcher)
def power(x, p):
"""
Return x to the power p, (x**p).
@@ -432,6 +460,8 @@ def power(x, p):
p = _fix_int_lt_zero(p)
return nx.power(x, p)
+
+@array_function_dispatch(_unary_dispatcher)
def arccos(x):
"""
Compute the inverse cosine of x.
@@ -475,6 +505,8 @@ def arccos(x):
x = _fix_real_abs_gt_1(x)
return nx.arccos(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arcsin(x):
"""
Compute the inverse sine of x.
@@ -519,6 +551,8 @@ def arcsin(x):
x = _fix_real_abs_gt_1(x)
return nx.arcsin(x)
+
+@array_function_dispatch(_unary_dispatcher)
def arctanh(x):
"""
Compute the inverse hyperbolic tangent of `x`.