summaryrefslogtreecommitdiff
path: root/numpy/lib/tests
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-10-04 15:38:40 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-10-04 19:04:23 +0200
commit27710a458c70298f5c2cb0ffefce0737e43aeaed (patch)
tree77c144ed8d6ca65998dfc7ad8f649c3c0bbbe7d7 /numpy/lib/tests
parente4444e5a0c1f11a5a1aab18fcbdde3da059d7f4d (diff)
downloadnumpy-27710a458c70298f5c2cb0ffefce0737e43aeaed.tar.gz
TST: Add a test for comparing the signatures of `nan<x>` and `<x>` functions
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r--numpy/lib/tests/test_nanfunctions.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3fdeec41c..3b0fa6656 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -1,5 +1,6 @@
import warnings
import pytest
+import inspect
import numpy as np
from numpy.lib.nanfunctions import _nan_mask, _replace_nan
@@ -35,6 +36,53 @@ _ndat_zeros = np.array([[0.6244, 0.0, 0.2692, 0.0116, 0.0, 0.1170],
[0.1610, 0.0, 0.0, 0.1859, 0.3146, 0.0]])
+class TestSignatureMatch:
+ NANFUNCS = {
+ np.nanmin: np.amin,
+ np.nanmax: np.amax,
+ np.nanargmin: np.argmin,
+ np.nanargmax: np.argmax,
+ np.nansum: np.sum,
+ np.nanprod: np.prod,
+ np.nancumsum: np.cumsum,
+ np.nancumprod: np.cumprod,
+ np.nanmean: np.mean,
+ np.nanmedian: np.median,
+ np.nanpercentile: np.percentile,
+ np.nanquantile: np.quantile,
+ np.nanvar: np.var,
+ np.nanstd: np.std,
+ }
+ IDS = [k.__name__ for k in NANFUNCS]
+
+ @staticmethod
+ def get_signature(func, default="..."):
+ """Construct a signature and replace all default parameter-values."""
+ prm_list = []
+ signature = inspect.signature(func)
+ for prm in signature.parameters.values():
+ if prm.default is inspect.Parameter.empty:
+ prm_list.append(prm)
+ else:
+ prm_list.append(prm.replace(default=default))
+ return inspect.Signature(prm_list)
+
+ @pytest.mark.parametrize("nan_func,func", NANFUNCS.items(), ids=IDS)
+ def test_signature_match(self, nan_func, func):
+ # Ignore the default parameter-values as they can sometimes differ
+ # between the two functions (*e.g.* one has `False` while the other
+ # has `np._NoValue`)
+ signature = self.get_signature(func)
+ nan_signature = self.get_signature(nan_func)
+ np.testing.assert_equal(signature, nan_signature)
+
+ def test_exhaustiveness(self):
+ """Validate that all nan functions are actually tested."""
+ np.testing.assert_equal(
+ set(self.IDS), set(np.lib.nanfunctions.__all__)
+ )
+
+
class TestNanFunctions_MinMax:
nanfuncs = [np.nanmin, np.nanmax]