diff options
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index b7261c63f..da2d0cc52 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -4,7 +4,7 @@ import warnings import pytest import numpy as np -from numpy.lib.nanfunctions import _nan_mask +from numpy.lib.nanfunctions import _nan_mask, _replace_nan from numpy.testing import ( assert_, assert_equal, assert_almost_equal, assert_no_warnings, assert_raises, assert_array_equal, suppress_warnings @@ -953,3 +953,30 @@ def test__nan_mask(arr, expected): # for types that can't possibly contain NaN if type(expected) is not np.ndarray: assert actual is True + + +def test__replace_nan(): + """ Test that _replace_nan returns the original array if there are no + NaNs, not a copy. + """ + for dtype in [np.bool, np.int32, np.int64]: + arr = np.array([0, 1], dtype=dtype) + result, mask = _replace_nan(arr, 0) + assert mask is None + # do not make a copy if there are no nans + assert result is arr + + for dtype in [np.float32, np.float64]: + arr = np.array([0, 1], dtype=dtype) + result, mask = _replace_nan(arr, 2) + assert (mask == False).all() + # mask is not None, so we make a copy + assert result is not arr + assert_equal(result, arr) + + arr_nan = np.array([0, 1, np.nan], dtype=dtype) + result_nan, mask_nan = _replace_nan(arr_nan, 2) + assert_equal(mask_nan, np.array([False, False, True])) + assert result_nan is not arr_nan + assert_equal(result_nan, np.array([0, 1, 2])) + assert np.isnan(arr_nan[-1]) |