summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_nanfunctions.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3fdeec41c..3b0fa6656 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -1,5 +1,6 @@
import warnings
import pytest
+import inspect
import numpy as np
from numpy.lib.nanfunctions import _nan_mask, _replace_nan
@@ -35,6 +36,53 @@ _ndat_zeros = np.array([[0.6244, 0.0, 0.2692, 0.0116, 0.0, 0.1170],
[0.1610, 0.0, 0.0, 0.1859, 0.3146, 0.0]])
+class TestSignatureMatch:
+ NANFUNCS = {
+ np.nanmin: np.amin,
+ np.nanmax: np.amax,
+ np.nanargmin: np.argmin,
+ np.nanargmax: np.argmax,
+ np.nansum: np.sum,
+ np.nanprod: np.prod,
+ np.nancumsum: np.cumsum,
+ np.nancumprod: np.cumprod,
+ np.nanmean: np.mean,
+ np.nanmedian: np.median,
+ np.nanpercentile: np.percentile,
+ np.nanquantile: np.quantile,
+ np.nanvar: np.var,
+ np.nanstd: np.std,
+ }
+ IDS = [k.__name__ for k in NANFUNCS]
+
+ @staticmethod
+ def get_signature(func, default="..."):
+ """Construct a signature and replace all default parameter-values."""
+ prm_list = []
+ signature = inspect.signature(func)
+ for prm in signature.parameters.values():
+ if prm.default is inspect.Parameter.empty:
+ prm_list.append(prm)
+ else:
+ prm_list.append(prm.replace(default=default))
+ return inspect.Signature(prm_list)
+
+ @pytest.mark.parametrize("nan_func,func", NANFUNCS.items(), ids=IDS)
+ def test_signature_match(self, nan_func, func):
+ # Ignore the default parameter-values as they can sometimes differ
+ # between the two functions (*e.g.* one has `False` while the other
+ # has `np._NoValue`)
+ signature = self.get_signature(func)
+ nan_signature = self.get_signature(nan_func)
+ np.testing.assert_equal(signature, nan_signature)
+
+ def test_exhaustiveness(self):
+ """Validate that all nan functions are actually tested."""
+ np.testing.assert_equal(
+ set(self.IDS), set(np.lib.nanfunctions.__all__)
+ )
+
+
class TestNanFunctions_MinMax:
nanfuncs = [np.nanmin, np.nanmax]