diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/arraypad.py | 2 | ||||
-rw-r--r-- | numpy/lib/arraysetops.py | 89 | ||||
-rw-r--r-- | numpy/lib/format.py | 7 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 5 | ||||
-rw-r--r-- | numpy/lib/nanfunctions.py | 28 | ||||
-rw-r--r-- | numpy/lib/npyio.py | 54 | ||||
-rw-r--r-- | numpy/lib/polynomial.py | 2 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 12 | ||||
-rw-r--r-- | numpy/lib/tests/test_format.py | 12 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 2 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 15 | ||||
-rw-r--r-- | numpy/lib/tests/test_polynomial.py | 8 | ||||
-rw-r--r-- | numpy/lib/tests/test_type_check.py | 18 | ||||
-rw-r--r-- | numpy/lib/type_check.py | 23 | ||||
-rw-r--r-- | numpy/lib/utils.py | 2 |
15 files changed, 166 insertions, 113 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py index 153b4af65..cdc354a02 100644 --- a/numpy/lib/arraypad.py +++ b/numpy/lib/arraypad.py @@ -1186,7 +1186,7 @@ def pad(array, pad_width, mode, **kwargs): reflect_type : {'even', 'odd'}, optional Used in 'reflect', and 'symmetric'. The 'even' style is the default with an unaltered reflection around the edge value. For - the 'odd' style, the extented part of the array is created by + the 'odd' style, the extended part of the array is created by subtracting the reflected values from two times the edge value. Returns diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index a9426cdf3..7b103ef3e 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -110,16 +110,25 @@ def ediff1d(ary, to_end=None, to_begin=None): return result +def _unpack_tuple(x): + """ Unpacks one-element tuples for use as return values """ + if len(x) == 1: + return x[0] + else: + return x + + def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ Find the unique elements of an array. Returns the sorted unique elements of an array. There are three optional - outputs in addition to the unique elements: the indices of the input array - that give the unique values, the indices of the unique array that - reconstruct the input array, and the number of times each unique value - comes up in the input array. + outputs in addition to the unique elements: + + * the indices of the input array that give the unique values + * the indices of the unique array that reconstruct the input array + * the number of times each unique value comes up in the input array Parameters ---------- @@ -135,13 +144,16 @@ def unique(ar, return_index=False, return_inverse=False, return_counts : bool, optional If True, also return the number of times each unique item appears in `ar`. + .. versionadded:: 1.9.0 + axis : int or None, optional The axis to operate on. If None, `ar` will be flattened beforehand. Otherwise, duplicate items will be removed along the provided axis, with all the other axes belonging to the each of the unique elements. Object arrays or structured arrays that contain objects are not supported if the `axis` kwarg is used. + .. versionadded:: 1.13.0 @@ -159,6 +171,7 @@ def unique(ar, return_index=False, return_inverse=False, unique_counts : ndarray, optional The number of times each of the unique values comes up in the original array. Only provided if `return_counts` is True. + .. versionadded:: 1.9.0 See Also @@ -207,11 +220,15 @@ def unique(ar, return_index=False, return_inverse=False, """ ar = np.asanyarray(ar) if axis is None: - return _unique1d(ar, return_index, return_inverse, return_counts) - if not (-ar.ndim <= axis < ar.ndim): - raise ValueError('Invalid axis kwarg specified for unique') + ret = _unique1d(ar, return_index, return_inverse, return_counts) + return _unpack_tuple(ret) + + try: + ar = np.swapaxes(ar, axis, 0) + except np.AxisError: + # this removes the "axis1" or "axis2" prefix from the error message + raise np.AxisError(axis, ar.ndim) - ar = np.swapaxes(ar, axis, 0) orig_shape, orig_dtype = ar.shape, ar.dtype # Must reshape to a contiguous 2D array for this to work... ar = ar.reshape(orig_shape[0], -1) @@ -241,11 +258,9 @@ def unique(ar, return_index=False, return_inverse=False, output = _unique1d(consolidated, return_index, return_inverse, return_counts) - if not (return_index or return_inverse or return_counts): - return reshape_uniq(output) - else: - uniq = reshape_uniq(output[0]) - return (uniq,) + output[1:] + output = (reshape_uniq(output[0]),) + output[1:] + return _unpack_tuple(output) + def _unique1d(ar, return_index=False, return_inverse=False, return_counts=False): @@ -255,20 +270,6 @@ def _unique1d(ar, return_index=False, return_inverse=False, ar = np.asanyarray(ar).flatten() optional_indices = return_index or return_inverse - optional_returns = optional_indices or return_counts - - if ar.size == 0: - if not optional_returns: - ret = ar - else: - ret = (ar,) - if return_index: - ret += (np.empty(0, np.intp),) - if return_inverse: - ret += (np.empty(0, np.intp),) - if return_counts: - ret += (np.empty(0, np.intp),) - return ret if optional_indices: perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') @@ -276,24 +277,24 @@ def _unique1d(ar, return_index=False, return_inverse=False, else: ar.sort() aux = ar - flag = np.concatenate(([True], aux[1:] != aux[:-1])) - - if not optional_returns: - ret = aux[flag] - else: - ret = (aux[flag],) - if return_index: - ret += (perm[flag],) - if return_inverse: - iflag = np.cumsum(flag) - 1 - inv_idx = np.empty(ar.shape, dtype=np.intp) - inv_idx[perm] = iflag - ret += (inv_idx,) - if return_counts: - idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) - ret += (np.diff(idx),) + mask = np.empty(aux.shape, dtype=np.bool_) + mask[:1] = True + mask[1:] = aux[1:] != aux[:-1] + + ret = (aux[mask],) + if return_index: + ret += (perm[mask],) + if return_inverse: + imask = np.cumsum(mask) - 1 + inv_idx = np.empty(mask.shape, dtype=np.intp) + inv_idx[perm] = imask + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(mask) + ([mask.size],)) + ret += (np.diff(idx),) return ret + def intersect1d(ar1, ar2, assume_unique=False): """ Find the intersection of two arrays. @@ -614,7 +615,7 @@ def union1d(ar1, ar2): >>> reduce(np.union1d, ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2])) array([1, 2, 3, 4, 6]) """ - return unique(np.concatenate((ar1, ar2))) + return unique(np.concatenate((ar1, ar2), axis=None)) def setdiff1d(ar1, ar2, assume_unique=False): """ diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 84af2afc8..363bb2101 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -454,7 +454,9 @@ def _filter_header(s): tokens = [] last_token_was_number = False - for token in tokenize.generate_tokens(StringIO(asstr(s)).read): + # adding newline as python 2.7.5 workaround + string = asstr(s) + "\n" + for token in tokenize.generate_tokens(StringIO(string).readline): token_type = token[0] token_string = token[1] if (last_token_was_number and @@ -464,7 +466,8 @@ def _filter_header(s): else: tokens.append(token) last_token_was_number = (token_type == tokenize.NUMBER) - return tokenize.untokenize(tokens) + # removing newline (see above) as python 2.7.5 workaround + return tokenize.untokenize(tokens)[:-1] def _read_array_header(fp, version): diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 979bc13d0..391c47a06 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1236,7 +1236,8 @@ def interp(x, xp, fp, left=None, right=None, period=None): >>> np.interp(x, xp, fp, period=360) array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75]) - Complex interpolation + Complex interpolation: + >>> x = [1.5, 4.0] >>> xp = [2,3,5] >>> fp = [1.0j, 0, 2+3j] @@ -3528,7 +3529,7 @@ def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False, else: zerod = False - # prepare a for partioning + # prepare a for partitioning if overwrite_input: if axis is None: ap = a.ravel() diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 4b2c9d817..16e363d7c 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -198,8 +198,8 @@ def nanmin(a, axis=None, out=None, keepdims=np._NoValue): a : array_like Array containing numbers whose minimum is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the minimum is computed. The default is to compute + axis : {int, tuple of int, None}, optional + Axis or axes along which the minimum is computed. The default is to compute the minimum of the flattened array. out : ndarray, optional Alternate output array in which to place the result. The default @@ -306,8 +306,8 @@ def nanmax(a, axis=None, out=None, keepdims=np._NoValue): a : array_like Array containing numbers whose maximum is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the maximum is computed. The default is to compute + axis : {int, tuple of int, None}, optional + Axis or axes along which the maximum is computed. The default is to compute the maximum of the flattened array. out : ndarray, optional Alternate output array in which to place the result. The default @@ -505,8 +505,8 @@ def nansum(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): a : array_like Array containing numbers whose sum is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the sum is computed. The default is to compute the + axis : {int, tuple of int, None}, optional + Axis or axes along which the sum is computed. The default is to compute the sum of the flattened array. dtype : data-type, optional The type of the returned array and of the accumulator in which the @@ -596,8 +596,8 @@ def nanprod(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): a : array_like Array containing numbers whose product is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the product is computed. The default is to compute + axis : {int, tuple of int, None}, optional + Axis or axes along which the product is computed. The default is to compute the product of the flattened array. dtype : data-type, optional The type of the returned array and of the accumulator in which the @@ -791,8 +791,8 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): a : array_like Array containing numbers whose mean is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the means are computed. The default is to compute + axis : {int, tuple of int, None}, optional + Axis or axes along which the means are computed. The default is to compute the mean of the flattened array. dtype : data-type, optional Type to use in computing the mean. For integer inputs, the default @@ -1217,8 +1217,8 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): a : array_like Array containing numbers whose variance is desired. If `a` is not an array, a conversion is attempted. - axis : int, optional - Axis along which the variance is computed. The default is to compute + axis : {int, tuple of int, None}, optional + Axis or axes along which the variance is computed. The default is to compute the variance of the flattened array. dtype : data-type, optional Type to use in computing the variance. For arrays of integer type @@ -1359,8 +1359,8 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): ---------- a : array_like Calculate the standard deviation of the non-NaN values. - axis : int, optional - Axis along which the standard deviation is computed. The default is + axis : {int, tuple of int, None}, optional + Axis or axes along which the standard deviation is computed. The default is to compute the standard deviation of the flattened array. dtype : dtype, optional Type to use in computing the standard deviation. For arrays of diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 9ee0aaaae..096f1a3a4 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -797,22 +797,23 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, The string used to separate values. For backwards compatibility, byte strings will be decoded as 'latin1'. The default is whitespace. converters : dict, optional - A dictionary mapping column number to a function that will convert - that column to a float. E.g., if column 0 is a date string: - ``converters = {0: datestr2num}``. Converters can also be used to - provide a default value for missing data (but see also `genfromtxt`): - ``converters = {3: lambda s: float(s.strip() or 0)}``. Default: None. + A dictionary mapping column number to a function that will parse the + column string into the desired value. E.g., if column 0 is a date + string: ``converters = {0: datestr2num}``. Converters can also be + used to provide a default value for missing data (but see also + `genfromtxt`): ``converters = {3: lambda s: float(s.strip() or 0)}``. + Default: None. skiprows : int, optional Skip the first `skiprows` lines; default: 0. usecols : int or sequence, optional Which columns to read, with 0 being the first. For example, - usecols = (1,4,5) will extract the 2nd, 5th and 6th columns. + ``usecols = (1,4,5)`` will extract the 2nd, 5th and 6th columns. The default, None, results in all columns being read. .. versionchanged:: 1.11.0 When a single column has to be read it is possible to use an integer instead of a tuple. E.g ``usecols = 3`` reads the - fourth column the same way as `usecols = (3,)`` would. + fourth column the same way as ``usecols = (3,)`` would. unpack : bool, optional If True, the returned array is transposed, so that arguments may be unpacked using ``x, y, z = loadtxt(...)``. When used with a structured @@ -827,7 +828,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, Encoding used to decode the inputfile. Does not apply to input streams. The special value 'bytes' enables backward compatibility workarounds that ensures you receive byte arrays as results if possible and passes - latin1 encoded strings to converters. Override this value to receive + 'latin1' encoded strings to converters. Override this value to receive unicode arrays and pass strings as input to converters. If set to None the system default is used. The default value is 'bytes'. @@ -2049,7 +2050,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, strcolidx = [i for (i, v) in enumerate(column_types) if v == np.unicode_] - type_str = np.unicode_ if byte_converters and strcolidx: # convert strings back to bytes for backward compatibility warnings.warn( @@ -2065,33 +2065,37 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, try: data = [encode_unicode_cols(r) for r in data] - type_str = np.bytes_ except UnicodeEncodeError: pass + else: + for i in strcolidx: + column_types[i] = np.bytes_ + # Update string types to be the right length + sized_column_types = column_types[:] + for i, col_type in enumerate(column_types): + if np.issubdtype(col_type, np.character): + n_chars = max(len(row[i]) for row in data) + sized_column_types[i] = (col_type, n_chars) - # ... and take the largest number of chars. - for i in strcolidx: - max_line_length = max(len(row[i]) for row in data) - column_types[i] = np.dtype((type_str, max_line_length)) - # if names is None: - # If the dtype is uniform, don't define names, else use '' - base = set([c.type for c in converters if c._checked]) + # If the dtype is uniform (before sizing strings) + base = set([ + c_type + for c, c_type in zip(converters, column_types) + if c._checked]) if len(base) == 1: - if strcolidx: - (ddtype, mdtype) = (type_str, bool) - else: - (ddtype, mdtype) = (list(base)[0], bool) + uniform_type, = base + (ddtype, mdtype) = (uniform_type, bool) else: ddtype = [(defaultfmt % i, dt) - for (i, dt) in enumerate(column_types)] + for (i, dt) in enumerate(sized_column_types)] if usemask: mdtype = [(defaultfmt % i, bool) - for (i, dt) in enumerate(column_types)] + for (i, dt) in enumerate(sized_column_types)] else: - ddtype = list(zip(names, column_types)) - mdtype = list(zip(names, [bool] * len(column_types))) + ddtype = list(zip(names, sized_column_types)) + mdtype = list(zip(names, [bool] * len(sized_column_types))) output = np.array(data, dtype=ddtype) if usemask: outputmask = np.array(masks, dtype=mdtype) diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index f49b7e295..41b5e2f64 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -897,7 +897,7 @@ def polydiv(u, v): n = len(v) - 1 scale = 1. / v[0] q = NX.zeros((max(m - n + 1, 1),), w.dtype) - r = u.copy() + r = u.astype(w.dtype) for k in range(0, m-n+1): d = scale * r[k] q[k] = d diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index b4787838d..17415d8fe 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -247,6 +247,14 @@ class TestSetOps(object): c = union1d(a, b) assert_array_equal(c, ec) + # Tests gh-10340, arguments to union1d should be + # flattened if they are not already 1D + x = np.array([[0, 1, 2], [3, 4, 5]]) + y = np.array([0, 1, 2, 3, 4]) + ez = np.array([0, 1, 2, 3, 4, 5]) + z = union1d(x, y) + assert_array_equal(z, ez) + assert_array_equal([], union1d([], [])) def test_setdiff1d(self): @@ -401,8 +409,8 @@ class TestUnique(object): assert_raises(TypeError, self._run_axis_tests, [('a', int), ('b', object)]) - assert_raises(ValueError, unique, np.arange(10), axis=2) - assert_raises(ValueError, unique, np.arange(10), axis=-2) + assert_raises(np.AxisError, unique, np.arange(10), axis=2) + assert_raises(np.AxisError, unique, np.arange(10), axis=-2) def test_unique_axis_list(self): msg = "Unique failed on list of lists" diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index 2d2b4cea2..d3bd2cef7 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -454,20 +454,20 @@ def assert_equal_(o1, o2): def test_roundtrip(): for arr in basic_arrays + record_arrays: arr2 = roundtrip(arr) - yield assert_array_equal, arr, arr2 + assert_array_equal(arr, arr2) def test_roundtrip_randsize(): for arr in basic_arrays + record_arrays: if arr.dtype != object: arr2 = roundtrip_randsize(arr) - yield assert_array_equal, arr, arr2 + assert_array_equal(arr, arr2) def test_roundtrip_truncated(): for arr in basic_arrays: if arr.dtype != object: - yield assert_raises, ValueError, roundtrip_truncated, arr + assert_raises(ValueError, roundtrip_truncated, arr) def test_long_str(): @@ -508,7 +508,7 @@ def test_memmap_roundtrip(): fp = open(mfn, 'rb') memmap_bytes = fp.read() fp.close() - yield assert_equal_, normal_bytes, memmap_bytes + assert_equal_(normal_bytes, memmap_bytes) # Check that reading the file using memmap works. ma = format.open_memmap(nfn, mode='r') @@ -728,13 +728,13 @@ def test_read_magic(): def test_read_magic_bad_magic(): for magic in malformed_magic: f = BytesIO(magic) - yield raises(ValueError)(format.read_magic), f + assert_raises(ValueError, format.read_array, f) def test_read_version_1_0_bad_magic(): for magic in bad_version_magic + malformed_magic: f = BytesIO(magic) - yield raises(ValueError)(format.read_array), f + assert_raises(ValueError, format.read_array, f) def test_bad_magic_args(): diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index b5e06dad0..0520ce580 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -245,7 +245,7 @@ class TestIndexExpression(object): class TestIx_(object): def test_regression_1(self): - # Test empty inputs create ouputs of indexing type, gh-5804 + # Test empty inputs create outputs of indexing type, gh-5804 # Test both lists and arrays for func in (range, np.arange): a, = np.ix_(func(0)) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 75a8e4968..d05fcd543 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -376,7 +376,7 @@ class TestSaveTxt(object): lines = c.readlines() assert_equal(lines, [b'01 : 2.0\n', b'03 : 4.0\n']) - # Specify delimiter, should be overiden + # Specify delimiter, should be overridden c = BytesIO() np.savetxt(c, a, fmt='%02d : %3.1f', delimiter=',') c.seek(0) @@ -1096,7 +1096,7 @@ class TestFromTxt(LoadTxtBase): assert_equal(test, control) def test_array(self): - # Test outputing a standard ndarray + # Test outputting a standard ndarray data = TextIO('1 2\n3 4') control = np.array([[1, 2], [3, 4]], dtype=int) test = np.ndfromtxt(data, dtype=int) @@ -2056,6 +2056,13 @@ M 33 21.99 assert_(isinstance(test, np.recarray)) assert_equal(test, control) + #gh-10394 + data = TextIO('color\n"red"\n"blue"') + test = np.recfromcsv(data, converters={0: lambda x: x.strip(b'\"')}) + control = np.array([('red',), ('blue',)], dtype=[('color', (bytes, 4))]) + assert_equal(test.dtype, control.dtype) + assert_equal(test, control) + def test_max_rows(self): # Test the `max_rows` keyword argument. data = '1 2\n3 4\n5 6\n7 8\n9 10\n' @@ -2226,7 +2233,7 @@ class TestPathUsage(object): @dec.skipif(Path is None, "No pathlib.Path") def test_ndfromtxt(self): - # Test outputing a standard ndarray + # Test outputting a standard ndarray with temppath(suffix='.txt') as path: path = Path(path) with path.open('w') as f: @@ -2292,7 +2299,7 @@ def test_gzip_load(): def test_gzip_loadtxt(): - # Thanks to another windows brokeness, we can't use + # Thanks to another windows brokenness, we can't use # NamedTemporaryFile: a file created from this function cannot be # reopened by another open call. So we first put the gzipped string # of the test reference array, write it to a securely opened file, diff --git a/numpy/lib/tests/test_polynomial.py b/numpy/lib/tests/test_polynomial.py index 9a4650825..03915cead 100644 --- a/numpy/lib/tests/test_polynomial.py +++ b/numpy/lib/tests/test_polynomial.py @@ -222,6 +222,14 @@ class TestDocs(object): assert_equal(p == p2, False) assert_equal(p != p2, True) + def test_polydiv(self): + b = np.poly1d([2, 6, 6, 1]) + a = np.poly1d([-1j, (1+2j), -(2+1j), 1]) + q, r = np.polydiv(b, a) + assert_equal(q.coeffs.dtype, np.complex128) + assert_equal(r.coeffs.dtype, np.complex128) + assert_equal(q*a + r, b) + def test_poly_coeffs_immutable(self): """ Coefficients should not be modifiable """ p = np.poly1d([1, 2, 3]) diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 8945b61ea..ce8ef2f15 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -359,6 +359,7 @@ class TestNanToNum(object): assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0])) assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + assert_equal(type(vals), np.ndarray) # perform the same test but in-place with np.errstate(divide='ignore', invalid='ignore'): @@ -369,16 +370,27 @@ class TestNanToNum(object): assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0])) assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + assert_equal(type(vals), np.ndarray) + + def test_array(self): + vals = nan_to_num([1]) + assert_array_equal(vals, np.array([1], int)) + assert_equal(type(vals), np.ndarray) def test_integer(self): vals = nan_to_num(1) assert_all(vals == 1) - vals = nan_to_num([1]) - assert_array_equal(vals, np.array([1], int)) + assert_equal(type(vals), np.int_) + + def test_float(self): + vals = nan_to_num(1.0) + assert_all(vals == 1.0) + assert_equal(type(vals), np.float_) def test_complex_good(self): vals = nan_to_num(1+1j) assert_all(vals == 1+1j) + assert_equal(type(vals), np.complex_) def test_complex_bad(self): with np.errstate(divide='ignore', invalid='ignore'): @@ -387,6 +399,7 @@ class TestNanToNum(object): vals = nan_to_num(v) # !! This is actually (unexpectedly) zero assert_all(np.isfinite(vals)) + assert_equal(type(vals), np.complex_) def test_complex_bad2(self): with np.errstate(divide='ignore', invalid='ignore'): @@ -394,6 +407,7 @@ class TestNanToNum(object): v += np.array(-1+1.j)/0. vals = nan_to_num(v) assert_all(np.isfinite(vals)) + assert_equal(type(vals), np.complex_) # Fixme #assert_all(vals.imag > 1e10) and assert_all(np.isfinite(vals)) # !! This is actually (unexpectedly) positive diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index 5c7528d4f..1664e6ebb 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -215,7 +215,7 @@ def iscomplex(x): if issubclass(ax.dtype.type, _nx.complexfloating): return ax.imag != 0 res = zeros(ax.shape, bool) - return +res # convet to array-scalar if needed + return +res # convert to array-scalar if needed def isreal(x): """ @@ -330,7 +330,7 @@ def _getmaxmin(t): def nan_to_num(x, copy=True): """ - Replace nan with zero and inf with large finite numbers. + Replace NaN with zero and infinity with large finite numbers. If `x` is inexact, NaN is replaced by zero, and infinity and -infinity replaced by the respectively largest and most negative finite floating @@ -343,7 +343,7 @@ def nan_to_num(x, copy=True): Parameters ---------- - x : array_like + x : scalar or array_like Input data. copy : bool, optional Whether to create a copy of `x` (True) or to replace values @@ -374,6 +374,12 @@ def nan_to_num(x, copy=True): Examples -------- + >>> np.nan_to_num(np.inf) + 1.7976931348623157e+308 + >>> np.nan_to_num(-np.inf) + -1.7976931348623157e+308 + >>> np.nan_to_num(np.nan) + 0.0 >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128]) >>> np.nan_to_num(x) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, @@ -386,20 +392,21 @@ def nan_to_num(x, copy=True): """ x = _nx.array(x, subok=True, copy=copy) xtype = x.dtype.type + + isscalar = (x.ndim == 0) + if not issubclass(xtype, _nx.inexact): - return x + return x[()] if isscalar else x iscomplex = issubclass(xtype, _nx.complexfloating) - isscalar = (x.ndim == 0) - x = x[None] if isscalar else x dest = (x.real, x.imag) if iscomplex else (x,) maxf, minf = _getmaxmin(x.real.dtype) for d in dest: _nx.copyto(d, 0.0, where=isnan(d)) _nx.copyto(d, maxf, where=isposinf(d)) _nx.copyto(d, minf, where=isneginf(d)) - return x[0] if isscalar else x + return x[()] if isscalar else x #----------------------------------------------------------------------------- @@ -579,7 +586,7 @@ def common_type(*arrays): an integer array, the minimum precision type that is returned is a 64-bit floating point dtype. - All input arrays except int64 and uint64 can be safely cast to the + All input arrays except int64 and uint64 can be safely cast to the returned dtype without loss of information. Parameters diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index e18eda0fb..1ecd334af 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -707,7 +707,7 @@ def lookfor(what, module=None, import_modules=True, regenerate=False, """ Do a keyword search on docstrings. - A list of of objects that matched the search is displayed, + A list of objects that matched the search is displayed, sorted by relevance. All given keywords need to be found in the docstring for it to be returned as a result, but the order does not matter. |