From 73151451437fa6ce0d8b5f033c1e005885f63cf8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 22 Oct 2018 17:40:08 -0700 Subject: ENH: __array_function__ support for np.lib, part 2/2 (#12119) * ENH: __array_function__ support for np.lib, part 2 xref GH12028 np.lib.npyio through np.lib.ufunclike * Fix failures in numpy/core/tests/test_overrides.py * CLN: handle depreaction in dispatchers for np.lib.ufunclike * CLN: fewer dispatchers in lib.twodim_base * CLN: fewer dispatchers in lib.shape_base * CLN: more dispatcher consolidation * BUG: fix test failure * Use all method instead of function in assert_equal * DOC: indicate n is array_like in scimath.logn * MAINT: updates per review * MAINT: more conservative changes in assert_array_equal * MAINT: add back in comment * MAINT: casting tweaks in assert_array_equal * MAINT: fixes and tests for assert_array_equal on subclasses --- numpy/lib/polynomial.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'numpy/lib/polynomial.py') diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index 9f3b84732..09079db9d 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -14,11 +14,13 @@ import numpy.core.numeric as NX from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array, ones) +from numpy.core.overrides import array_function_dispatch from numpy.lib.twodim_base import diag, vander from numpy.lib.function_base import trim_zeros from numpy.lib.type_check import iscomplex, real, imag, mintypecode from numpy.linalg import eigvals, lstsq, inv + class RankWarning(UserWarning): """ Issued by `polyfit` when the Vandermonde matrix is rank deficient. @@ -29,6 +31,12 @@ class RankWarning(UserWarning): """ pass + +def _poly_dispatcher(seq_of_zeros): + return seq_of_zeros + + +@array_function_dispatch(_poly_dispatcher) def poly(seq_of_zeros): """ Find the coefficients of a polynomial with the given sequence of roots. @@ -145,6 +153,12 @@ def poly(seq_of_zeros): return a + +def _roots_dispatcher(p): + return p + + +@array_function_dispatch(_roots_dispatcher) def roots(p): """ Return the roots of a polynomial with coefficients given in p. @@ -229,6 +243,12 @@ def roots(p): roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype))) return roots + +def _polyint_dispatcher(p, m=None, k=None): + return (p,) + + +@array_function_dispatch(_polyint_dispatcher) def polyint(p, m=1, k=None): """ Return an antiderivative (indefinite integral) of a polynomial. @@ -322,6 +342,12 @@ def polyint(p, m=1, k=None): return poly1d(val) return val + +def _polyder_dispatcher(p, m=None): + return (p,) + + +@array_function_dispatch(_polyder_dispatcher) def polyder(p, m=1): """ Return the derivative of the specified order of a polynomial. @@ -390,6 +416,12 @@ def polyder(p, m=1): val = poly1d(val) return val + +def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None): + return (x, y, w) + + +@array_function_dispatch(_polyfit_dispatcher) def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): """ Least squares polynomial fit. @@ -610,6 +642,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): return c +def _polyval_dispatcher(p, x): + return (p, x) + + +@array_function_dispatch(_polyval_dispatcher) def polyval(p, x): """ Evaluate a polynomial at specific values. @@ -679,6 +716,12 @@ def polyval(p, x): y = y * x + p[i] return y + +def _binary_op_dispatcher(a1, a2): + return (a1, a2) + + +@array_function_dispatch(_binary_op_dispatcher) def polyadd(a1, a2): """ Find the sum of two polynomials. @@ -739,6 +782,8 @@ def polyadd(a1, a2): val = poly1d(val) return val + +@array_function_dispatch(_binary_op_dispatcher) def polysub(a1, a2): """ Difference (subtraction) of two polynomials. @@ -786,6 +831,7 @@ def polysub(a1, a2): return val +@array_function_dispatch(_binary_op_dispatcher) def polymul(a1, a2): """ Find the product of two polynomials. @@ -842,6 +888,12 @@ def polymul(a1, a2): val = poly1d(val) return val + +def _polydiv_dispatcher(u, v): + return (u, v) + + +@array_function_dispatch(_polydiv_dispatcher) def polydiv(u, v): """ Returns the quotient and remainder of polynomial division. -- cgit v1.2.1