Diffstat (limited to 'numpy/lib')
-rw-r--r--   numpy/lib/function_base.py              124
-rw-r--r--   numpy/lib/tests/test_function_base.py    49
-rw-r--r--   numpy/lib/tests/test_utils.py             20
-rw-r--r--   numpy/lib/utils.py                        44
4 files changed, 134 insertions, 103 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 0b23dbebd..48b0a0830 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1332,6 +1332,10 @@ def interp(x, xp, fp, left=None, right=None, period=None):
If `xp` or `fp` are not 1-D sequences
If `period == 0`
+ See Also
+ --------
+ scipy.interpolate
+
Notes
-----
The x-coordinate sequence is expected to be increasing, but this is not
@@ -3869,15 +3873,20 @@ def _quantile_is_valid(q):
return True
+def _lerp(a, b, t, out=None):
+ """ Linearly interpolate from a to b by a factor of t """
+ return add(a*(1 - t), b*t, out=out)
+
+
def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
interpolation='linear', keepdims=False):
a = asarray(a)
- if q.ndim == 0:
- # Do not allow 0-d arrays because following code fails for scalar
- zerod = True
- q = q[None]
- else:
- zerod = False
+
+ # ufuncs cause 0d array results to decay to scalars (see gh-13105), which
+ # makes them problematic for __setitem__ and attribute access. As a
+ # workaround, we call this on the result of every ufunc on a possibly-0d
+ # array.
+ not_scalar = np.asanyarray
# prepare a for partitioning
if overwrite_input:
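
(Aside, not part of the patch: a minimal standalone illustration of the 0-d decay described in the comment above and in gh-13105. A ufunc applied to a 0-d array returns a NumPy scalar, which rejects item assignment, so the code wraps such results with np.asanyarray, aliased as not_scalar.)

    import numpy as np

    q = np.array(0.5)              # 0-d array
    idx = q * 9                    # ufunc result decays to a scalar
    print(type(idx))               # <class 'numpy.float64'>, no __setitem__

    idx = np.asanyarray(q * 9)     # the `not_scalar` workaround
    print(idx.ndim, idx.shape)     # 0 () -- still an array, so indexing works
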
@@ -3894,9 +3903,14 @@ def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
if axis is None:
axis = 0
- Nx = ap.shape[axis]
- indices = q * (Nx - 1)
+ if q.ndim > 2:
+ # The code below works fine for nd, but it might not have useful
+ # semantics. For now, keep the supported dimensions the same as it was
+ # before.
+ raise ValueError("q must be a scalar or 1d")
+ Nx = ap.shape[axis]
+ indices = not_scalar(q * (Nx - 1))
# round fractional indices according to interpolation method
if interpolation == 'lower':
indices = floor(indices).astype(intp)
@@ -3913,74 +3927,60 @@ def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
"interpolation can only be 'linear', 'lower' 'higher', "
"'midpoint', or 'nearest'")
- n = np.array(False, dtype=bool) # check for nan's flag
- if np.issubdtype(indices.dtype, np.integer): # take the points along axis
- # Check if the array contains any nan's
- if np.issubdtype(a.dtype, np.inexact):
- indices = concatenate((indices, [-1]))
+ # The dimensions of `q` are prepended to the output shape, so we need the
+ # axis being sampled from `ap` to be first.
+ ap = np.moveaxis(ap, axis, 0)
+ del axis
- ap.partition(indices, axis=axis)
- # ensure axis with q-th is first
- ap = np.moveaxis(ap, axis, 0)
- axis = 0
+ if np.issubdtype(indices.dtype, np.integer):
+ # take the points along axis
- # Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact):
- indices = indices[:-1]
- n = np.isnan(ap[-1:, ...])
-
- if zerod:
- indices = indices[0]
- r = take(ap, indices, axis=axis, out=out)
-
- else: # weight the points above and below the indices
- indices_below = floor(indices).astype(intp)
- indices_above = indices_below + 1
- indices_above[indices_above > Nx - 1] = Nx - 1
-
- # Check if the array contains any nan's
- if np.issubdtype(a.dtype, np.inexact):
- indices_above = concatenate((indices_above, [-1]))
+ # may contain nan, which would sort to the end
+ ap.partition(concatenate((indices.ravel(), [-1])), axis=0)
+ n = np.isnan(ap[-1])
+ else:
+ # cannot contain nan
+ ap.partition(indices.ravel(), axis=0)
+ n = np.array(False, dtype=bool)
- ap.partition(concatenate((indices_below, indices_above)), axis=axis)
+ r = take(ap, indices, axis=0, out=out)
- # ensure axis with q-th is first
- ap = np.moveaxis(ap, axis, 0)
- axis = 0
+ else:
+ # weight the points above and below the indices
- weights_shape = [1] * ap.ndim
- weights_shape[axis] = len(indices)
- weights_above = (indices - indices_below).reshape(weights_shape)
+ indices_below = not_scalar(floor(indices)).astype(intp)
+ indices_above = not_scalar(indices_below + 1)
+ indices_above[indices_above > Nx - 1] = Nx - 1
- # Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact):
- indices_above = indices_above[:-1]
- n = np.isnan(ap[-1:, ...])
+ # may contain nan, which would sort to the end
+ ap.partition(concatenate((
+ indices_below.ravel(), indices_above.ravel(), [-1]
+ )), axis=0)
+ n = np.isnan(ap[-1])
+ else:
+ # cannot contain nan
+ ap.partition(concatenate((
+ indices_below.ravel(), indices_above.ravel()
+ )), axis=0)
+ n = np.array(False, dtype=bool)
- x1 = take(ap, indices_below, axis=axis) * (1 - weights_above)
- x2 = take(ap, indices_above, axis=axis) * weights_above
+ weights_shape = indices.shape + (1,) * (ap.ndim - 1)
+ weights_above = not_scalar(indices - indices_below).reshape(weights_shape)
- if zerod:
- x1 = x1.squeeze(0)
- x2 = x2.squeeze(0)
+ x_below = take(ap, indices_below, axis=0)
+ x_above = take(ap, indices_above, axis=0)
- r = add(x1, x2, out=out)
+ r = _lerp(x_below, x_above, weights_above, out=out)
+ # if any slice contained a nan, then all results on that slice are also nan
if np.any(n):
- if zerod:
- if ap.ndim == 1:
- if out is not None:
- out[...] = a.dtype.type(np.nan)
- r = out
- else:
- r = a.dtype.type(np.nan)
- else:
- r[..., n.squeeze(0)] = a.dtype.type(np.nan)
+ if r.ndim == 0 and out is None:
+ # can't write to a scalar
+ r = a.dtype.type(np.nan)
else:
- if r.ndim == 1:
- r[:] = a.dtype.type(np.nan)
- else:
- r[..., n.repeat(q.size, 0)] = a.dtype.type(np.nan)
+ r[..., n] = a.dtype.type(np.nan)
return r
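
(Aside, a sketch of what the rewritten 'linear' branch computes; the variable names here are mine, not the patch's. The target position q*(Nx-1) is split into a floor index, a clipped ceiling index, and a fractional weight, and the result is a linear interpolation between the two partitioned order statistics, which is exactly the _lerp formula added above.)

    import numpy as np

    def lerp(a, b, t):
        # same formula as the _lerp helper added in this diff
        return a * (1 - t) + b * t

    a = np.array([4., 1., 3., 2.])
    q = 0.5
    nx = a.size                              # 4
    pos = q * (nx - 1)                       # 1.5
    below = int(np.floor(pos))               # 1
    above = min(below + 1, nx - 1)           # 2
    weight = pos - below                     # 0.5

    part = np.partition(a, (below, above))   # needed order statistics in place
    result = lerp(part[below], part[above], weight)
    print(result, np.quantile(a, q))         # 2.5 2.5
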
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index b4e928273..9ba0be56a 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2423,6 +2423,15 @@ class TestBincount:
assert_equal(sys.getrefcount(np.dtype(np.intp)), intp_refcount)
assert_equal(sys.getrefcount(np.dtype(np.double)), double_refcount)
+ @pytest.mark.parametrize("vals", [[[2, 2]], 2])
+ def test_error_not_1d(self, vals):
+ # Test that values has to be 1-D (both as array and nested list)
+ vals_arr = np.asarray(vals)
+ with assert_raises(ValueError):
+ np.bincount(vals_arr)
+ with assert_raises(ValueError):
+ np.bincount(vals)
+
class TestInterp:
@@ -3350,12 +3359,50 @@ class TestAdd_newdoc:
@pytest.mark.skipif(sys.flags.optimize == 2, reason="Python running -OO")
@pytest.mark.xfail(IS_PYPY, reason="PyPy does not modify tp_doc")
def test_add_doc(self):
- # test np.add_newdoc
+ # test that np.add_newdoc did attach a docstring successfully:
tgt = "Current flat index into the array."
assert_equal(np.core.flatiter.index.__doc__[:len(tgt)], tgt)
assert_(len(np.core.ufunc.identity.__doc__) > 300)
assert_(len(np.lib.index_tricks.mgrid.__doc__) > 300)
+ @pytest.mark.skipif(sys.flags.optimize == 2, reason="Python running -OO")
+ def test_errors_are_ignored(self):
+ prev_doc = np.core.flatiter.index.__doc__
+ # nothing changed, but error ignored, this should probably
+ # give a warning (or even error) in the future.
+ np.add_newdoc("numpy.core", "flatiter", ("index", "bad docstring"))
+ assert prev_doc == np.core.flatiter.index.__doc__
+
+
+class TestAddDocstring():
+ # Test should possibly be moved, but it also fits to be close to
+ # the newdoc tests...
+ @pytest.mark.skipif(sys.flags.optimize == 2, reason="Python running -OO")
+ @pytest.mark.skipif(IS_PYPY, reason="PyPy does not modify tp_doc")
+ def test_add_same_docstring(self):
+ # test for attributes (which are C-level defined)
+ np.add_docstring(np.ndarray.flat, np.ndarray.flat.__doc__)
+ # And typical functions:
+ def func():
+ """docstring"""
+ return
+
+ np.add_docstring(func, func.__doc__)
+
+ @pytest.mark.skipif(sys.flags.optimize == 2, reason="Python running -OO")
+ def test_different_docstring_fails(self):
+ # test for attributes (which are C-level defined)
+ with assert_raises(RuntimeError):
+ np.add_docstring(np.ndarray.flat, "different docstring")
+ # And typical functions:
+ def func():
+ """docstring"""
+ return
+
+ with assert_raises(RuntimeError):
+ np.add_docstring(func, "different docstring")
+
+
class TestSortComplex:
@pytest.mark.parametrize("type_in, type_out", [
diff --git a/numpy/lib/tests/test_utils.py b/numpy/lib/tests/test_utils.py
index c96bf795a..261cfef5d 100644
--- a/numpy/lib/tests/test_utils.py
+++ b/numpy/lib/tests/test_utils.py
@@ -49,7 +49,7 @@ def old_func5(self, x):
Bizarre indentation.
"""
return x
-new_func5 = deprecate(old_func5)
+new_func5 = deprecate(old_func5, message="This function is\ndeprecated.")
def old_func6(self, x):
@@ -74,10 +74,20 @@ def test_deprecate_fn():
@pytest.mark.skipif(sys.flags.optimize == 2, reason="-OO discards docstrings")
-def test_deprecate_help_indentation():
- _compare_docs(old_func4, new_func4)
- _compare_docs(old_func5, new_func5)
- _compare_docs(old_func6, new_func6)
+@pytest.mark.parametrize('old_func, new_func', [
+ (old_func4, new_func4),
+ (old_func5, new_func5),
+ (old_func6, new_func6),
+])
+def test_deprecate_help_indentation(old_func, new_func):
+ _compare_docs(old_func, new_func)
+ # Ensure we don't mess up the indentation
+ for knd, func in (('old', old_func), ('new', new_func)):
+ for li, line in enumerate(func.__doc__.split('\n')):
+ if li == 0:
+ assert line.startswith('    ') or not line.startswith(' '), knd
+ elif line:
+ assert line.startswith('    '), knd
def _compare_docs(old_func, new_func):
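
(Aside, illustrating what the parametrized indentation check above guards; old_func5 here is the same stand-in used in the test file. numpy.deprecate prepends the deprecation note to the wrapped docstring, and after the utils.py change below the note is indented to line up with the original docstring body.)

    import numpy as np

    def old_func5(self, x):
        """
        Summary.

            Bizarre indentation.
        """
        return x

    new_func5 = np.deprecate(old_func5, message="This function is\ndeprecated.")

    # The wrapped docstring starts with the deprecation note plus the two
    # custom message lines, indented to match the original body, while the
    # "Bizarre indentation" block keeps its deeper indent.
    print(new_func5.__doc__)
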
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index f233c7240..d511c2a40 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1,5 +1,6 @@
import os
import sys
+import textwrap
import types
import re
import warnings
@@ -9,9 +10,6 @@ from numpy.core.overrides import set_module
from numpy.core import ndarray, ufunc, asarray
import numpy as np
-# getargspec and formatargspec were removed in Python 3.6
-from numpy.compat import getargspec, formatargspec
-
__all__ = [
'issubclass_', 'issubsctype', 'issubdtype', 'deprecate',
'deprecate_with_doc', 'get_include', 'info', 'source', 'who',
@@ -117,6 +115,7 @@ class _Deprecate:
break
skip += len(line) + 1
doc = doc[skip:]
+ depdoc = textwrap.indent(depdoc, ' ' * indent)
doc = '\n\n'.join([depdoc, doc])
newfunc.__doc__ = doc
try:
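
(Aside, a small standalone illustration of the single line added in the hunk above; the values are made up. textwrap.indent prefixes every non-blank line of the deprecation note so it aligns with the indentation inferred from the original docstring.)

    import textwrap

    depdoc = "`old_func` is deprecated!\nUse `new_func` instead."
    indent = 4  # inferred from the wrapped function's docstring

    print(textwrap.indent(depdoc, ' ' * indent))
    #     `old_func` is deprecated!
    #     Use `new_func` instead.
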
@@ -552,9 +551,12 @@ def info(object=None, maxwidth=76, output=sys.stdout, toplevel='numpy'):
file=output
)
- elif inspect.isfunction(object):
+ elif inspect.isfunction(object) or inspect.ismethod(object):
name = object.__name__
- arguments = formatargspec(*getargspec(object))
+ try:
+ arguments = str(inspect.signature(object))
+ except Exception:
+ arguments = "()"
if len(name+arguments) > maxwidth:
argstr = _split_line(name, arguments, maxwidth)
@@ -566,18 +568,10 @@ def info(object=None, maxwidth=76, output=sys.stdout, toplevel='numpy'):
elif inspect.isclass(object):
name = object.__name__
- arguments = "()"
try:
- if hasattr(object, '__init__'):
- arguments = formatargspec(
- *getargspec(object.__init__.__func__)
- )
- arglist = arguments.split(', ')
- if len(arglist) > 1:
- arglist[1] = "("+arglist[1]
- arguments = ", ".join(arglist[1:])
+ arguments = str(inspect.signature(object))
except Exception:
- pass
+ arguments = "()"
if len(name+arguments) > maxwidth:
argstr = _split_line(name, arguments, maxwidth)
@@ -605,26 +599,6 @@ def info(object=None, maxwidth=76, output=sys.stdout, toplevel='numpy'):
)
print(" %s -- %s" % (meth, methstr), file=output)
- elif inspect.ismethod(object):
- name = object.__name__
- arguments = formatargspec(
- *getargspec(object.__func__)
- )
- arglist = arguments.split(', ')
- if len(arglist) > 1:
- arglist[1] = "("+arglist[1]
- arguments = ", ".join(arglist[1:])
- else:
- arguments = "()"
-
- if len(name+arguments) > maxwidth:
- argstr = _split_line(name, arguments, maxwidth)
- else:
- argstr = name + arguments
-
- print(" " + argstr + "\n", file=output)
- print(inspect.getdoc(object), file=output)
-
elif hasattr(object, '__doc__'):
print(inspect.getdoc(object), file=output)
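
(Aside, a sketch of the replacement API rather than patch code: inspect.signature handles plain functions, methods, and classes uniformly, which is why the separate ismethod branch and the getargspec-based class handling could be dropped. The try/except keeps the old "()" fallback for objects whose signature cannot be introspected.)

    import inspect
    import numpy as np

    for obj in (np.clip, np.ndarray.astype, dict):
        try:
            arguments = str(inspect.signature(obj))
        except Exception:
            arguments = "()"          # same fallback as in info()
        print(obj.__name__ + arguments)
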