summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-03-12 22:45:16 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-03-12 23:13:59 -0700
commit43c79ff448534e1d672e5c6013f9659d27d69aa0 (patch)
tree89d9b63028478412663a8ecbe881a7783e681516 /numpy
parenta9790fe223a15419c68aa1dd6ee6ab45ad4b96c8 (diff)
downloadnumpy-43c79ff448534e1d672e5c6013f9659d27d69aa0.tar.gz
MAINT: Unify polynomial division functions
These division functions are all the same - the algorithm used does not care about the basis. Note that while chebdiv and polydiv could be implemented in terms of this function, their current implementations are more optimal and exploit the properties of a multiplication by a basis polynomial.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/polynomial/chebyshev.py1
-rw-r--r--numpy/polynomial/hermite.py21
-rw-r--r--numpy/polynomial/hermite_e.py21
-rw-r--r--numpy/polynomial/laguerre.py21
-rw-r--r--numpy/polynomial/legendre.py21
-rw-r--r--numpy/polynomial/polynomial.py1
-rw-r--r--numpy/polynomial/polyutils.py36
7 files changed, 42 insertions, 80 deletions
diff --git a/numpy/polynomial/chebyshev.py b/numpy/polynomial/chebyshev.py
index fd05280e9..0eef90177 100644
--- a/numpy/polynomial/chebyshev.py
+++ b/numpy/polynomial/chebyshev.py
@@ -793,6 +793,7 @@ def chebdiv(c1, c2):
if c2[-1] == 0:
raise ZeroDivisionError()
+ # note: this is more efficient than `pu._div(chebmul, c1, c2)`
lc1 = len(c1)
lc2 = len(c2)
if lc1 < lc2:
diff --git a/numpy/polynomial/hermite.py b/numpy/polynomial/hermite.py
index 605cb29ad..3767a80fc 100644
--- a/numpy/polynomial/hermite.py
+++ b/numpy/polynomial/hermite.py
@@ -550,26 +550,7 @@ def hermdiv(c1, c2):
(array([1., 2., 3.]), array([1., 1.]))
"""
- # c1, c2 are trimmed copies
- [c1, c2] = pu.as_series([c1, c2])
- if c2[-1] == 0:
- raise ZeroDivisionError()
-
- lc1 = len(c1)
- lc2 = len(c2)
- if lc1 < lc2:
- return c1[:1]*0, c1
- elif lc2 == 1:
- return c1/c2[-1], c1[:1]*0
- else:
- quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype)
- rem = c1
- for i in range(lc1 - lc2, - 1, -1):
- p = hermmul([0]*i + [1], c2)
- q = rem[-1]/p[-1]
- rem = rem[:-1] - q*p[:-1]
- quo[i] = q
- return quo, pu.trimseq(rem)
+ return pu._div(hermmul, c1, c2)
def hermpow(c, pow, maxpower=16):
diff --git a/numpy/polynomial/hermite_e.py b/numpy/polynomial/hermite_e.py
index b28881013..228396457 100644
--- a/numpy/polynomial/hermite_e.py
+++ b/numpy/polynomial/hermite_e.py
@@ -545,26 +545,7 @@ def hermediv(c1, c2):
(array([1., 2., 3.]), array([1., 2.]))
"""
- # c1, c2 are trimmed copies
- [c1, c2] = pu.as_series([c1, c2])
- if c2[-1] == 0:
- raise ZeroDivisionError()
-
- lc1 = len(c1)
- lc2 = len(c2)
- if lc1 < lc2:
- return c1[:1]*0, c1
- elif lc2 == 1:
- return c1/c2[-1], c1[:1]*0
- else:
- quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype)
- rem = c1
- for i in range(lc1 - lc2, - 1, -1):
- p = hermemul([0]*i + [1], c2)
- q = rem[-1]/p[-1]
- rem = rem[:-1] - q*p[:-1]
- quo[i] = q
- return quo, pu.trimseq(rem)
+ return pu._div(hermemul, c1, c2)
def hermepow(c, pow, maxpower=16):
diff --git a/numpy/polynomial/laguerre.py b/numpy/polynomial/laguerre.py
index 575c5b2bc..dec469c17 100644
--- a/numpy/polynomial/laguerre.py
+++ b/numpy/polynomial/laguerre.py
@@ -547,26 +547,7 @@ def lagdiv(c1, c2):
(array([1., 2., 3.]), array([1., 1.]))
"""
- # c1, c2 are trimmed copies
- [c1, c2] = pu.as_series([c1, c2])
- if c2[-1] == 0:
- raise ZeroDivisionError()
-
- lc1 = len(c1)
- lc2 = len(c2)
- if lc1 < lc2:
- return c1[:1]*0, c1
- elif lc2 == 1:
- return c1/c2[-1], c1[:1]*0
- else:
- quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype)
- rem = c1
- for i in range(lc1 - lc2, - 1, -1):
- p = lagmul([0]*i + [1], c2)
- q = rem[-1]/p[-1]
- rem = rem[:-1] - q*p[:-1]
- quo[i] = q
- return quo, pu.trimseq(rem)
+ return pu._div(lagmul, c1, c2)
def lagpow(c, pow, maxpower=16):
diff --git a/numpy/polynomial/legendre.py b/numpy/polynomial/legendre.py
index 6cd4360da..5f8ce53a3 100644
--- a/numpy/polynomial/legendre.py
+++ b/numpy/polynomial/legendre.py
@@ -590,26 +590,7 @@ def legdiv(c1, c2):
(array([-0.07407407, 1.66666667]), array([-1.03703704, -2.51851852])) # may vary
"""
- # c1, c2 are trimmed copies
- [c1, c2] = pu.as_series([c1, c2])
- if c2[-1] == 0:
- raise ZeroDivisionError()
-
- lc1 = len(c1)
- lc2 = len(c2)
- if lc1 < lc2:
- return c1[:1]*0, c1
- elif lc2 == 1:
- return c1/c2[-1], c1[:1]*0
- else:
- quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype)
- rem = c1
- for i in range(lc1 - lc2, - 1, -1):
- p = legmul([0]*i + [1], c2)
- q = rem[-1]/p[-1]
- rem = rem[:-1] - q*p[:-1]
- quo[i] = q
- return quo, pu.trimseq(rem)
+ return pu._div(legmul, c1, c2)
def legpow(c, pow, maxpower=16):
diff --git a/numpy/polynomial/polynomial.py b/numpy/polynomial/polynomial.py
index 8c6b604d5..f63d9dd74 100644
--- a/numpy/polynomial/polynomial.py
+++ b/numpy/polynomial/polynomial.py
@@ -397,6 +397,7 @@ def polydiv(c1, c2):
if c2[-1] == 0:
raise ZeroDivisionError()
+ # note: this is more efficient than `pu._div(polymul, c1, c2)`
lc1 = len(c1)
lc2 = len(c2)
if lc1 < lc2:
diff --git a/numpy/polynomial/polyutils.py b/numpy/polynomial/polyutils.py
index db1cb2841..1d5390984 100644
--- a/numpy/polynomial/polyutils.py
+++ b/numpy/polynomial/polyutils.py
@@ -537,3 +537,39 @@ def _gridnd(val_f, c, *args):
for xi in args:
c = val_f(xi, c)
return c
+
+
+def _div(mul_f, c1, c2):
+ """
+ Helper function used to implement the ``<type>div`` functions.
+
+ Implementation uses repeated subtraction of c2 multiplied by the nth basis.
+ For some polynomial types, a more efficient approach may be possible.
+
+ Parameters
+ ----------
+ mul_f : function(array_like, array_like) -> array_like
+ The ``<type>mul`` function, such as ``polymul``
+ c1, c2 :
+ See the ``<type>div`` functions for more detail
+ """
+ # c1, c2 are trimmed copies
+ [c1, c2] = as_series([c1, c2])
+ if c2[-1] == 0:
+ raise ZeroDivisionError()
+
+ lc1 = len(c1)
+ lc2 = len(c2)
+ if lc1 < lc2:
+ return c1[:1]*0, c1
+ elif lc2 == 1:
+ return c1/c2[-1], c1[:1]*0
+ else:
+ quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype)
+ rem = c1
+ for i in range(lc1 - lc2, - 1, -1):
+ p = mul_f([0]*i + [1], c2)
+ q = rem[-1]/p[-1]
+ rem = rem[:-1] - q*p[:-1]
+ quo[i] = q
+ return quo, trimseq(rem)