diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-03-11 22:42:04 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-03-12 20:38:15 -0700 |
commit | fcea19a3dd586bbf9d62719de551ac75d3b4e17a (patch) | |
tree | 22fd41d0166b1a7dbe887ef3b09db9d5d5ad3f72 /numpy/polynomial/polyutils.py | |
parent | 1bb279ad4f25d987155106ee6f82ba7fc83ce5a0 (diff) | |
download | numpy-fcea19a3dd586bbf9d62719de551ac75d3b4e17a.tar.gz |
MAINT: Unify polynomial valnd functions
No point writing the same function 12 times, when you can write it once
Diffstat (limited to 'numpy/polynomial/polyutils.py')
-rw-r--r-- | numpy/polynomial/polyutils.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/polynomial/polyutils.py b/numpy/polynomial/polyutils.py index 9482ed89f..db1cb2841 100644 --- a/numpy/polynomial/polyutils.py +++ b/numpy/polynomial/polyutils.py @@ -489,3 +489,51 @@ def _fromroots(line_f, mul_f, roots): p = tmp n = m return p[0] + + +def _valnd(val_f, c, *args): + """ + Helper function used to implement the ``<type>val<n>d`` functions. + + Parameters + ---------- + val_f : function(array_like, array_like, tensor: bool) -> array_like + The ``<type>val`` function, such as ``polyval`` + c, args : + See the ``<type>val<n>d`` functions for more detail + """ + try: + args = tuple(np.array(args, copy=False)) + except Exception: + # preserve the old error message + if len(args) == 2: + raise ValueError('x, y, z are incompatible') + elif len(args) == 3: + raise ValueError('x, y are incompatible') + else: + raise ValueError('ordinates are incompatible') + + it = iter(args) + x0 = next(it) + + # use tensor on only the first + c = val_f(x0, c) + for xi in it: + c = val_f(xi, c, tensor=False) + return c + + +def _gridnd(val_f, c, *args): + """ + Helper function used to implement the ``<type>grid<n>d`` functions. + + Parameters + ---------- + val_f : function(array_like, array_like, tensor: bool) -> array_like + The ``<type>val`` function, such as ``polyval`` + c, args : + See the ``<type>grid<n>d`` functions for more detail + """ + for xi in args: + c = val_f(xi, c) + return c |