summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-08-13 18:16:26 -0600
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-09-04 21:45:48 +0200
commit175cea4dc0590ae520f32476ffb9129b8524bcad (patch)
tree6d48e045a3fbdb3b0a8088732ec8d62db5f11f33 /numpy/core
parent2995792b90add947f8898ea36410473f37b94509 (diff)
downloadnumpy-175cea4dc0590ae520f32476ffb9129b8524bcad.tar.gz
ENH: When cblas is available use it in descr->f->dot.
Importing _dotblas currently executes _dotblas.alterdot, which replaces the default descr->f->dot function with a cblas based version for float, double, complex float, and complex double data types. This PR changes the default descr->f->dot to use cblas whenever it is available. After this change, the alterdot and restoredot functions serve no purpose, so are changed to do nothing and deprecated. Note that those functions were already doing nothing when _dotblas was not available.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/blasdot/_dotblas.c135
-rw-r--r--numpy/core/bscript16
-rw-r--r--numpy/core/numeric.py18
-rw-r--r--numpy/core/setup.py21
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src189
-rw-r--r--numpy/core/tests/test_blasdot.py4
-rw-r--r--numpy/core/tests/test_deprecations.py42
7 files changed, 226 insertions, 199 deletions
diff --git a/numpy/core/blasdot/_dotblas.c b/numpy/core/blasdot/_dotblas.c
index 48aa39ff8..c1b4e7b05 100644
--- a/numpy/core/blasdot/_dotblas.c
+++ b/numpy/core/blasdot/_dotblas.c
@@ -24,121 +24,6 @@ static char module_doc[] =
static PyArray_DotFunc *oldFunctions[NPY_NTYPES];
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-
-/*
- * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done
- * (BLAS won't handle negative or zero strides the way we want).
- */
-static NPY_INLINE int
-blas_stride(npy_intp stride, unsigned itemsize)
-{
- if (stride <= 0 || stride % itemsize != 0) {
- return 0;
- }
- stride /= itemsize;
-
- if (stride > INT_MAX) {
- return 0;
- }
- return stride;
-}
-
-/*
- * The following functions do a "chunked" dot product using BLAS when
- * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not
- * handle more than INT_MAX elements per call.
- *
- * The chunksize is the greatest power of two less than INT_MAX.
- */
-#if NPY_MAX_INTP > INT_MAX
-# define CHUNKSIZE (INT_MAX / 2 + 1)
-#else
-# define CHUNKSIZE NPY_MAX_INTP
-#endif
-
-static void
-FLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
- npy_intp n, void *tmp)
-{
- int na = blas_stride(stridea, sizeof(float));
- int nb = blas_stride(strideb, sizeof(float));
-
- if (na && nb) {
- double r = 0.; /* double for stability */
- float *fa = a, *fb = b;
-
- while (n > 0) {
- int chunk = MIN(n, CHUNKSIZE);
-
- r += cblas_sdot(chunk, fa, na, fb, nb);
- fa += chunk * na;
- fb += chunk * nb;
- n -= chunk;
- }
- *((float *)res) = r;
- }
- else {
- oldFunctions[NPY_FLOAT](a, stridea, b, strideb, res, n, tmp);
- }
-}
-
-static void
-DOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
- npy_intp n, void *tmp)
-{
- int na = blas_stride(stridea, sizeof(double));
- int nb = blas_stride(strideb, sizeof(double));
-
- if (na && nb) {
- double r = 0.;
- double *da = a, *db = b;
-
- while (n > 0) {
- int chunk = MIN(n, CHUNKSIZE);
-
- r += cblas_ddot(chunk, da, na, db, nb);
- da += chunk * na;
- db += chunk * nb;
- n -= chunk;
- }
- *((double *)res) = r;
- }
- else {
- oldFunctions[NPY_DOUBLE](a, stridea, b, strideb, res, n, tmp);
- }
-}
-
-static void
-CFLOAT_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
- npy_intp n, void *tmp)
-{
- int na = blas_stride(stridea, sizeof(npy_cfloat));
- int nb = blas_stride(strideb, sizeof(npy_cfloat));
-
- if (na && nb) {
- cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
- }
- else {
- oldFunctions[NPY_CFLOAT](a, stridea, b, strideb, res, n, tmp);
- }
-}
-
-static void
-CDOUBLE_dot(void *a, npy_intp stridea, void *b, npy_intp strideb, void *res,
- npy_intp n, void *tmp)
-{
- int na = blas_stride(stridea, sizeof(npy_cdouble));
- int nb = blas_stride(strideb, sizeof(npy_cdouble));
-
- if (na && nb) {
- cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb,
- (double *)res);
- }
- else {
- oldFunctions[NPY_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
- }
-}
/*
* Helper: call appropriate BLAS dot function for typenum.
@@ -149,20 +34,8 @@ blas_dot(int typenum, npy_intp n,
void *a, npy_intp stridea, void *b, npy_intp strideb, void *res)
{
PyArray_DotFunc *dot = NULL;
- switch (typenum) {
- case NPY_DOUBLE:
- dot = DOUBLE_dot;
- break;
- case NPY_FLOAT:
- dot = FLOAT_dot;
- break;
- case NPY_CDOUBLE:
- dot = CDOUBLE_dot;
- break;
- case NPY_CFLOAT:
- dot = CFLOAT_dot;
- break;
- }
+
+ dot = oldFunctions[typenum];
assert(dot != NULL);
dot(a, stridea, b, strideb, res, n, NULL);
}
@@ -257,19 +130,15 @@ dotblas_alterdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
if (!altered) {
descr = PyArray_DescrFromType(NPY_FLOAT);
oldFunctions[NPY_FLOAT] = descr->f->dotfunc;
- descr->f->dotfunc = (PyArray_DotFunc *)FLOAT_dot;
descr = PyArray_DescrFromType(NPY_DOUBLE);
oldFunctions[NPY_DOUBLE] = descr->f->dotfunc;
- descr->f->dotfunc = (PyArray_DotFunc *)DOUBLE_dot;
descr = PyArray_DescrFromType(NPY_CFLOAT);
oldFunctions[NPY_CFLOAT] = descr->f->dotfunc;
- descr->f->dotfunc = (PyArray_DotFunc *)CFLOAT_dot;
descr = PyArray_DescrFromType(NPY_CDOUBLE);
oldFunctions[NPY_CDOUBLE] = descr->f->dotfunc;
- descr->f->dotfunc = (PyArray_DotFunc *)CDOUBLE_dot;
altered = NPY_TRUE;
}
diff --git a/numpy/core/bscript b/numpy/core/bscript
index 416e16524..3306c5341 100644
--- a/numpy/core/bscript
+++ b/numpy/core/bscript
@@ -488,14 +488,22 @@ def pre_build(context):
pjoin('src', 'multiarray', 'usertypes.c')]
else:
sources = extension.sources
+
+ use = 'npysort npymath'
+ defines = ['_FILE_OFFSET_BITS=64',
+ '_LARGEFILE_SOURCE=1',
+ '_LARGEFILE64_SOURCE=1']
+
+ if bld.env.HAS_CBLAS:
+ use += ' CBLAS'
+ defines.append('HAVE_CBLAS')
+
includes = ["src/multiarray", "src/private"]
return context.default_builder(extension,
includes=includes,
source=sources,
- use="npysort npymath",
- defines=['_FILE_OFFSET_BITS=64',
- '_LARGEFILE_SOURCE=1',
- '_LARGEFILE64_SOURCE=1']
+ use=use,
+ defines=defines
)
context.register_builder("multiarray", builder_multiarray)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index eb0c38b0b..4361ba5c1 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1078,31 +1078,33 @@ def outer(a, b, out=None):
# try to import blas optimized dot if available
envbak = os.environ.copy()
try:
- # importing this changes the dot function for basic 4 types
- # to blas-optimized versions.
-
# disables openblas affinity setting of the main thread that limits
# python threads or processes to one core
if 'OPENBLAS_MAIN_FREE' not in os.environ:
os.environ['OPENBLAS_MAIN_FREE'] = '1'
if 'GOTOBLAS_MAIN_FREE' not in os.environ:
os.environ['GOTOBLAS_MAIN_FREE'] = '1'
- from ._dotblas import dot, vdot, inner, alterdot, restoredot
+ from ._dotblas import dot, vdot, inner
except ImportError:
# docstrings are in add_newdocs.py
inner = multiarray.inner
dot = multiarray.dot
def vdot(a, b):
return dot(asarray(a).ravel().conj(), asarray(b).ravel())
- def alterdot():
- pass
- def restoredot():
- pass
finally:
os.environ.clear()
os.environ.update(envbak)
del envbak
+
+def alterdot():
+ warnings.warn("alterdot no longer does anything.", DeprecationWarning)
+
+
+def restoredot():
+ warnings.warn("restoredot no longer does anything.", DeprecationWarning)
+
+
def tensordot(a, b, axes=2):
"""
Compute tensor dot product along specified axes for arrays >= 1-D.
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index b81374d14..4c2af5e62 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -846,15 +846,22 @@ def configuration(parent_package='',top_path=None):
multiarray_src = [join('src', 'multiarray', 'multiarraymodule_onefile.c')]
multiarray_src.append(generate_multiarray_templated_sources)
+ blas_info = get_info('blas_opt', 0)
+ if blas_info and ('HAVE_CBLAS', None) in blas_info.get('define_macros', []):
+ extra_info = blas_info
+ else:
+ extra_info = {}
+
config.add_extension('multiarray',
- sources = multiarray_src +
+ sources=multiarray_src +
[generate_config_h,
- generate_numpyconfig_h,
- generate_numpy_api,
- join(codegen_dir, 'generate_numpy_api.py'),
- join('*.py')],
- depends = deps + multiarray_deps,
- libraries = ['npymath', 'npysort'])
+ generate_numpyconfig_h,
+ generate_numpy_api,
+ join(codegen_dir, 'generate_numpy_api.py'),
+ join('*.py')],
+ depends=deps + multiarray_deps,
+ libraries=['npymath', 'npysort'],
+ extra_info=extra_info)
#######################################################################
# umath module #
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index e4904acfc..1b7bbde63 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -3,6 +3,11 @@
#include "Python.h"
#include "structmember.h"
+#include <limits.h>
+#if defined(HAVE_CBLAS)
+#include <cblas.h>
+#endif
+
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define _MULTIARRAYMODULE
#include "numpy/npy_common.h"
@@ -3094,6 +3099,149 @@ static int
* dot means inner product
*/
+/************************** MAYBE USE CBLAS *********************************/
+
+/*
+ * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done
+ * (BLAS won't handle negative or zero strides the way we want).
+ */
+#if defined(HAVE_CBLAS)
+static NPY_INLINE int
+blas_stride(npy_intp stride, unsigned itemsize)
+{
+ /*
+ * Should probably check pointer alignment also, but this may cause
+ * problems if we require complex to be 16 byte aligned.
+ */
+ if (stride > 0 && npy_is_aligned((void *)stride, itemsize)) {
+ stride /= itemsize;
+ if (stride <= INT_MAX) {
+ return stride;
+ }
+ }
+ return 0;
+}
+#endif
+
+
+/*
+ * The following functions do a "chunked" dot product using BLAS when
+ * sizeof(npy_intp) > sizeof(int), because BLAS libraries can typically not
+ * handle more than INT_MAX elements per call.
+ *
+ * The chunksize is the greatest power of two less than INT_MAX.
+ */
+#if NPY_MAX_INTP > INT_MAX
+# define CHUNKSIZE (INT_MAX / 2 + 1)
+#else
+# define CHUNKSIZE NPY_MAX_INTP
+#endif
+
+/**begin repeat
+ *
+ * #name = FLOAT, DOUBLE#
+ * #type = npy_float, npy_double#
+ * #prefix = s, d#
+ */
+static void
+@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op,
+ npy_intp n, void *NPY_UNUSED(ignore))
+{
+#if defined(HAVE_CBLAS)
+ int is1b = blas_stride(is1, sizeof(@type@));
+ int is2b = blas_stride(is2, sizeof(@type@));
+
+ if (is1b && is2b)
+ {
+ double sum = 0.; /* double for stability */
+
+ while (n > 0) {
+ int chunk = n < CHUNKSIZE ? n : CHUNKSIZE;
+
+ sum += cblas_@prefix@dot(chunk,
+ (@type@ *) ip1, is1b,
+ (@type@ *) ip2, is2b);
+ /* use char strides here */
+ ip1 += chunk * is1;
+ ip2 += chunk * is2;
+ n -= chunk;
+ }
+ *((@type@ *)op) = (@type@)sum;
+ }
+ else
+#endif
+ {
+ @type@ sum = (@type@)0; /* could make this double */
+ npy_intp i;
+
+ for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
+ const @type@ ip1r = *((@type@ *)ip1);
+ const @type@ ip2r = *((@type@ *)ip2);
+
+ sum += ip1r * ip2r;
+ }
+ *((@type@ *)op) = sum;
+ }
+}
+/**end repeat**/
+
+/**begin repeat
+ *
+ * #name = CFLOAT, CDOUBLE#
+ * #ctype = npy_cfloat, npy_cdouble#
+ * #type = npy_float, npy_double#
+ * #prefix = c, z#
+ */
+static void
+@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
+ char *op, npy_intp n, void *NPY_UNUSED(ignore))
+{
+#if defined(HAVE_CBLAS)
+ int is1b = blas_stride(is1, sizeof(@ctype@));
+ int is2b = blas_stride(is2, sizeof(@ctype@));
+
+ if (is1b && is2b) {
+ double sum[2] = {0., 0.}; /* double for stability */
+
+ while (n > 0) {
+ int chunk = n < CHUNKSIZE ? n : CHUNKSIZE;
+ @type@ tmp[2];
+
+ cblas_@prefix@dotu_sub((int)n, ip1, is1b, ip2, is2b, tmp);
+ sum[0] += (double)tmp[0];
+ sum[1] += (double)tmp[1];
+ /* use char strides here */
+ ip1 += chunk * is1;
+ ip2 += chunk * is2;
+ n -= chunk;
+ }
+ ((@type@ *)op)[0] = (@type@)sum[0];
+ ((@type@ *)op)[1] = (@type@)sum[1];
+ }
+ else
+#endif
+ {
+ @type@ sumr = (@type@)0.0;
+ @type@ sumi = (@type@)0.0;
+ npy_intp i;
+
+ for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
+ const @type@ ip1r = ((@type@ *)ip1)[0];
+ const @type@ ip1i = ((@type@ *)ip1)[1];
+ const @type@ ip2r = ((@type@ *)ip2)[0];
+ const @type@ ip2i = ((@type@ *)ip2)[1];
+
+ sumr += ip1r * ip2r - ip1i * ip2i;
+ sumi += ip1i * ip2r + ip1r * ip2i;
+ }
+ ((@type@ *)op)[0] = sumr;
+ ((@type@ *)op)[1] = sumi;
+ }
+}
+/**end repeat**/
+
+/**************************** NO CBLAS VERSIONS *****************************/
+
static void
BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
void *NPY_UNUSED(ignore))
@@ -3114,16 +3262,13 @@ BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
*
* #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG,
- * FLOAT, DOUBLE, LONGDOUBLE,
- * DATETIME, TIMEDELTA#
+ * LONGDOUBLE, DATETIME, TIMEDELTA#
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
- * npy_float, npy_double, npy_longdouble,
- * npy_datetime, npy_timedelta#
+ * npy_longdouble, npy_datetime, npy_timedelta#
* #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
- * npy_float, npy_double, npy_longdouble,
- * npy_datetime, npy_timedelta#
+ * npy_longdouble, npy_datetime, npy_timedelta#
*/
static void
@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
@@ -3141,8 +3286,8 @@ static void
/**end repeat**/
static void
-HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
- void *NPY_UNUSED(ignore))
+HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op,
+ npy_intp n, void *NPY_UNUSED(ignore))
{
float tmp = 0.0f;
npy_intp i;
@@ -3154,28 +3299,26 @@ HALF_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
*((npy_half *)op) = npy_float_to_half(tmp);
}
-/**begin repeat
- *
- * #name = CFLOAT, CDOUBLE, CLONGDOUBLE#
- * #type = npy_float, npy_double, npy_longdouble#
- */
-static void @name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
- char *op, npy_intp n, void *NPY_UNUSED(ignore))
+static void CLONGDOUBLE_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
+ char *op, npy_intp n, void *NPY_UNUSED(ignore))
{
- @type@ tmpr = (@type@)0.0, tmpi=(@type@)0.0;
+ npy_longdouble tmpr = 0.0L;
+ npy_longdouble tmpi = 0.0L;
npy_intp i;
for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
- tmpr += ((@type@ *)ip1)[0] * ((@type@ *)ip2)[0]
- - ((@type@ *)ip1)[1] * ((@type@ *)ip2)[1];
- tmpi += ((@type@ *)ip1)[1] * ((@type@ *)ip2)[0]
- + ((@type@ *)ip1)[0] * ((@type@ *)ip2)[1];
+ const npy_longdouble ip1r = ((npy_longdouble *)ip1)[0];
+ const npy_longdouble ip1i = ((npy_longdouble *)ip1)[1];
+ const npy_longdouble ip2r = ((npy_longdouble *)ip2)[0];
+ const npy_longdouble ip2i = ((npy_longdouble *)ip2)[1];
+
+ tmpr += ip1r * ip2r - ip1i * ip2i;
+ tmpi += ip1i * ip2r + ip1r * ip2i;
}
- ((@type@ *)op)[0] = tmpr; ((@type@ *)op)[1] = tmpi;
+ ((npy_longdouble *)op)[0] = tmpr;
+ ((npy_longdouble *)op)[1] = tmpi;
}
-/**end repeat**/
-
static void
OBJECT_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
void *NPY_UNUSED(ignore))
diff --git a/numpy/core/tests/test_blasdot.py b/numpy/core/tests/test_blasdot.py
index 17f77d2f5..c38dab187 100644
--- a/numpy/core/tests/test_blasdot.py
+++ b/numpy/core/tests/test_blasdot.py
@@ -26,12 +26,10 @@ except ImportError:
@dec.skipif(_dotblas is None, "Numpy is not compiled with _dotblas")
def test_blasdot_used():
- from numpy.core import dot, vdot, inner, alterdot, restoredot
+ from numpy.core import dot, vdot, inner
assert_(dot is _dotblas.dot)
assert_(vdot is _dotblas.vdot)
assert_(inner is _dotblas.inner)
- assert_(alterdot is _dotblas.alterdot)
- assert_(restoredot is _dotblas.restoredot)
def test_dot_2args():
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index ef56766f5..9e2248205 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -5,13 +5,11 @@ to document how deprecations should eventually be turned into errors.
"""
from __future__ import division, absolute_import, print_function
-import sys
import operator
import warnings
-from nose.plugins.skip import SkipTest
import numpy as np
-from numpy.testing import (dec, run_module_suite, assert_raises,
+from numpy.testing import (run_module_suite, assert_raises,
assert_warns, assert_array_equal, assert_)
@@ -34,11 +32,9 @@ class _DeprecationTestCase(object):
warnings.filterwarnings("always", message=self.message,
category=DeprecationWarning)
-
def tearDown(self):
self.warn_ctx.__exit__()
-
def assert_deprecated(self, function, num=1, ignore_others=False,
function_fails=False,
exceptions=(DeprecationWarning,), args=(), kwargs={}):
@@ -102,7 +98,6 @@ class _DeprecationTestCase(object):
if exceptions == tuple():
raise AssertionError("Error raised during function call")
-
def assert_not_deprecated(self, function, args=(), kwargs={}):
"""Test if DeprecationWarnings are given and raised.
@@ -143,6 +138,7 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
def test_indexing(self):
a = np.array([[[5]]])
+
def assert_deprecated(*args, **kwargs):
self.assert_deprecated(*args, exceptions=(IndexError,), **kwargs)
@@ -172,7 +168,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
assert_deprecated(lambda: a[0.0:, 0.0], num=2)
assert_deprecated(lambda: a[0.0:, 0.0,:], num=2)
-
def test_valid_indexing(self):
a = np.array([[[5]]])
assert_not_deprecated = self.assert_not_deprecated
@@ -183,9 +178,9 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
assert_not_deprecated(lambda: a[:, 0,:])
assert_not_deprecated(lambda: a[:,:,:])
-
def test_slicing(self):
a = np.array([[5]])
+
def assert_deprecated(*args, **kwargs):
self.assert_deprecated(*args, exceptions=(IndexError,), **kwargs)
@@ -217,7 +212,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
# should still get the DeprecationWarning if step = 0.
assert_deprecated(lambda: a[::0.0], function_fails=True)
-
def test_valid_slicing(self):
a = np.array([[[5]]])
assert_not_deprecated = self.assert_not_deprecated
@@ -231,7 +225,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
assert_not_deprecated(lambda: a[:2:2])
assert_not_deprecated(lambda: a[1:2:2])
-
def test_non_integer_argument_deprecations(self):
a = np.array([[5]])
@@ -240,7 +233,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
self.assert_deprecated(np.take, args=(a, [0], 1.))
self.assert_deprecated(np.take, args=(a, [0], np.float64(1.)))
-
def test_non_integer_sequence_multiplication(self):
# Numpy scalar sequence multiply should not work with non-integers
def mult(a, b):
@@ -248,7 +240,6 @@ class TestFloatNonIntegerArgumentDeprecation(_DeprecationTestCase):
self.assert_deprecated(mult, args=([1], np.float_(3)))
self.assert_not_deprecated(mult, args=([1], np.int_(3)))
-
def test_reduce_axis_float_index(self):
d = np.zeros((3,3,3))
self.assert_deprecated(np.min, args=(d, 0.5))
@@ -303,7 +294,6 @@ class TestArrayToIndexDeprecation(_DeprecationTestCase):
# Check slicing. Normal indexing checks arrays specifically.
self.assert_deprecated(lambda: a[a:a:a], exceptions=(), num=3)
-
class TestNonIntegerArrayLike(_DeprecationTestCase):
"""Tests that array likes, i.e. lists give a deprecation warning
when they cannot be safely cast to an integer.
@@ -320,7 +310,6 @@ class TestNonIntegerArrayLike(_DeprecationTestCase):
self.assert_not_deprecated(a.__getitem__, ([],))
-
def test_boolean_futurewarning(self):
a = np.arange(10)
with warnings.catch_warnings():
@@ -378,12 +367,13 @@ class TestRankDeprecation(_DeprecationTestCase):
"""Test that np.rank is deprecated. The function should simply be
removed. The VisibleDeprecationWarning may become unnecessary.
"""
+
def test(self):
a = np.arange(10)
assert_warns(np.VisibleDeprecationWarning, np.rank, a)
-class TestComparisonDepreactions(_DeprecationTestCase):
+class TestComparisonDeprecations(_DeprecationTestCase):
"""This tests the deprecation, for non-elementwise comparison logic.
This used to mean that when an error occured during element-wise comparison
(i.e. broadcasting) NotImplemented was returned, but also in the comparison
@@ -408,7 +398,6 @@ class TestComparisonDepreactions(_DeprecationTestCase):
b = np.array([1, np.array([1,2,3])], dtype=object)
self.assert_deprecated(op, args=(a, b), num=None)
-
def test_string(self):
# For two string arrays, strings always raised the broadcasting error:
a = np.array(['a', 'b'])
@@ -420,7 +409,6 @@ class TestComparisonDepreactions(_DeprecationTestCase):
# following works (and returns False) due to dtype mismatch:
a == []
-
def test_none_comparison(self):
# Test comparison of None, which should result in elementwise
# comparison in the future. [1, 2] == None should be [False, False].
@@ -455,14 +443,14 @@ class TestComparisonDepreactions(_DeprecationTestCase):
assert_(np.equal(np.datetime64('NaT'), None))
-class TestIdentityComparisonDepreactions(_DeprecationTestCase):
+class TestIdentityComparisonDeprecations(_DeprecationTestCase):
"""This tests the equal and not_equal object ufuncs identity check
deprecation. This was due to the usage of PyObject_RichCompareBool.
This tests that for example for `a = np.array([np.nan], dtype=object)`
`a == a` it is warned that False and not `np.nan is np.nan` is returned.
- Should be kept in sync with TestComparisonDepreactions and new tests
+ Should be kept in sync with TestComparisonDeprecations and new tests
added when the deprecation is over. Requires only removing of @identity@
(and blocks) from the ufunc loops.c.src of the OBJECT comparisons.
"""
@@ -488,11 +476,11 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase):
np.less_equal(a, a)
np.greater_equal(a, a)
-
def test_comparison_error(self):
class FunkyType(object):
def __eq__(self, other):
raise TypeError("I won't compare")
+
def __ne__(self, other):
raise TypeError("I won't compare")
@@ -500,7 +488,6 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase):
self.assert_deprecated(np.equal, args=(a, a))
self.assert_deprecated(np.not_equal, args=(a, a))
-
def test_bool_error(self):
# The comparison result cannot be interpreted as a bool
a = np.array([np.array([1, 2, 3]), None], dtype=object)
@@ -508,5 +495,18 @@ class TestIdentityComparisonDepreactions(_DeprecationTestCase):
self.assert_deprecated(np.not_equal, args=(a, a))
+class TestAlterdotRestoredotDeprecations(_DeprecationTestCase):
+ """The alterdot/restoredot functions are deprecated.
+
+ These functions no longer do anything in numpy 1.10, so should not be
+ used.
+
+ """
+
+ def test_alterdot_restoredot_deprecation(self):
+ self.assert_deprecated(np.alterdot)
+ self.assert_deprecated(np.restoredot)
+
+
if __name__ == "__main__":
run_module_suite()