summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2019-05-26 12:04:44 -0700
committerStephan Hoyer <shoyer@google.com>2019-05-26 12:04:44 -0700
commit9216a1dd5391013bf0670bb3d9d07ef68c3b7d66 (patch)
tree99924c389973ffe0792f08cbd93730c5d10dd4fc
parent37df5e641f19d9e5221cb519532b6cc5647ebe53 (diff)
downloadnumpy-9216a1dd5391013bf0670bb3d9d07ef68c3b7d66.tar.gz
MAINT: Fixes tests with __array_function__ disabled
-rw-r--r--numpy/core/overrides.py4
-rw-r--r--numpy/core/shape_base.py11
-rw-r--r--numpy/core/tests/test_overrides.py4
-rw-r--r--numpy/lib/shape_base.py10
-rw-r--r--numpy/lib/ufunclike.py22
-rw-r--r--numpy/testing/tests/test_utils.py3
6 files changed, 44 insertions, 10 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 13688e0a7..04a5a995f 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -9,7 +9,7 @@ from numpy.core._multiarray_umath import (
from numpy.compat._inspect import getargspec
-ENABLE_ARRAY_FUNCTION = bool(
+ARRAY_FUNCTION_ENABLED = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
@@ -142,7 +142,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
Function suitable for decorating the implementation of a NumPy function.
"""
- if not ENABLE_ARRAY_FUNCTION:
+ if not ARRAY_FUNCTION_ENABLED:
def decorator(implementation):
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 3a037e7d2..45115adb6 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -273,6 +273,9 @@ def vstack(tup):
[4]])
"""
+ if not overrides.ARRAY_FUNCTION_ENABLED:
+ # raise warning if necessary
+ _arrays_for_stack_dispatcher(tup, stacklevel=2)
return _nx.concatenate([atleast_2d(_m) for _m in tup], 0)
@@ -324,6 +327,10 @@ def hstack(tup):
[3, 4]])
"""
+ if not overrides.ARRAY_FUNCTION_ENABLED:
+ # raise warning if necessary
+ _arrays_for_stack_dispatcher(tup, stacklevel=2)
+
arrs = [atleast_1d(_m) for _m in tup]
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
if arrs and arrs[0].ndim == 1:
@@ -400,6 +407,10 @@ def stack(arrays, axis=0, out=None):
[3, 4]])
"""
+ if not overrides.ARRAY_FUNCTION_ENABLED:
+ # raise warning if necessary
+ _arrays_for_stack_dispatcher(arrays, stacklevel=2)
+
arrays = [asanyarray(arr) for arr in arrays]
if not arrays:
raise ValueError('need at least one array to stack')
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index dc8bf0613..63b0e4539 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -9,13 +9,13 @@ from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
from numpy.core.overrides import (
_get_implementing_args, array_function_dispatch,
- verify_matching_signatures, ENABLE_ARRAY_FUNCTION)
+ verify_matching_signatures, ARRAY_FUNCTION_ENABLED)
from numpy.compat import pickle
import pytest
requires_array_function = pytest.mark.skipif(
- not ENABLE_ARRAY_FUNCTION,
+ not ARRAY_FUNCTION_ENABLED,
reason="__array_function__ dispatch not enabled.")
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index f2517b474..10f567bb4 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -7,7 +7,7 @@ import numpy.core.numeric as _nx
from numpy.core.numeric import (
asarray, zeros, outer, concatenate, array, asanyarray
)
-from numpy.core.fromnumeric import product, reshape, transpose
+from numpy.core.fromnumeric import reshape, transpose
from numpy.core.multiarray import normalize_axis_index
from numpy.core import overrides
from numpy.core import vstack, atleast_3d
@@ -628,6 +628,10 @@ def column_stack(tup):
[3, 4]])
"""
+ if not overrides.ARRAY_FUNCTION_ENABLED:
+ # raise warning if necessary
+ _arrays_for_stack_dispatcher(tup, stacklevel=2)
+
arrays = []
for v in tup:
arr = array(v, copy=False, subok=True)
@@ -692,6 +696,10 @@ def dstack(tup):
[[3, 4]]])
"""
+ if not overrides.ARRAY_FUNCTION_ENABLED:
+ # raise warning if necessary
+ _arrays_for_stack_dispatcher(tup, stacklevel=2)
+
return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py
index 8452604d9..96fd5b319 100644
--- a/numpy/lib/ufunclike.py
+++ b/numpy/lib/ufunclike.py
@@ -8,7 +8,9 @@ from __future__ import division, absolute_import, print_function
__all__ = ['fix', 'isneginf', 'isposinf']
import numpy.core.numeric as nx
-from numpy.core.overrides import array_function_dispatch
+from numpy.core.overrides import (
+ array_function_dispatch, ARRAY_FUNCTION_ENABLED,
+)
import warnings
import functools
@@ -43,7 +45,7 @@ def _fix_out_named_y(f):
Allow the out argument to be passed as the name `y` (deprecated)
This decorator should only be used if _deprecate_out_named_y is used on
- a corresponding dispatcher fucntion.
+ a corresponding dispatcher function.
"""
@functools.wraps(f)
def func(x, out=None, **kwargs):
@@ -55,13 +57,23 @@ def _fix_out_named_y(f):
return func
+def _fix_and_maybe_deprecate_out_named_y(f):
+ """
+ Use the appropriate decorator, depending upon if dispatching is being used.
+ """
+ if ARRAY_FUNCTION_ENABLED:
+ return _fix_out_named_y(f)
+ else:
+ return _deprecate_out_named_y(f)
+
+
@_deprecate_out_named_y
def _dispatcher(x, out=None):
return (x, out)
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
-@_fix_out_named_y
+@_fix_and_maybe_deprecate_out_named_y
def fix(x, out=None):
"""
Round to nearest integer towards zero.
@@ -108,7 +120,7 @@ def fix(x, out=None):
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
-@_fix_out_named_y
+@_fix_and_maybe_deprecate_out_named_y
def isposinf(x, out=None):
"""
Test element-wise for positive infinity, return result as bool array.
@@ -177,7 +189,7 @@ def isposinf(x, out=None):
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
-@_fix_out_named_y
+@_fix_and_maybe_deprecate_out_named_y
def isneginf(x, out=None):
"""
Test element-wise for negative infinity, return result as bool array.
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 643d143ee..f17a3ccd3 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -17,6 +17,7 @@ from numpy.testing import (
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
)
+from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
class _GenericTest(object):
@@ -179,6 +180,8 @@ class TestArrayEqual(_GenericTest):
self._test_not_equal(a, b)
self._test_not_equal(b, a)
+ @pytest.mark.skipif(
+ not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
def test_subclass_that_does_not_implement_npall(self):
class MyArray(np.ndarray):
def __array_function__(self, *args, **kwargs):