diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-11-21 21:53:10 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-21 21:53:10 -0600 |
commit | f01924f92208cc8b40c5830bb240b3c26cbaea08 (patch) | |
tree | 850794921aef547c14a48be870aca3574d4f5904 /numpy | |
parent | 4f1541e1cb68beb3049a21cbdec6e3d30c2afbbb (diff) | |
parent | 73322a0ba923ff6e6429463ba1e31acc69ce1ea7 (diff) | |
download | numpy-f01924f92208cc8b40c5830bb240b3c26cbaea08.tar.gz |
Merge pull request #12431 from eric-wieser/ndpointer-return
BUG/ENH: Fix use of ndpointer in return values
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/_multiarray_tests.c.src | 13 | ||||
-rw-r--r-- | numpy/ctypeslib.py | 57 | ||||
-rw-r--r-- | numpy/tests/test_ctypeslib.py | 75 |
3 files changed, 120 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src index b98c3afb3..f05ee1431 100644 --- a/numpy/core/src/multiarray/_multiarray_tests.c.src +++ b/numpy/core/src/multiarray/_multiarray_tests.c.src @@ -11,6 +11,13 @@ #include "npy_extint128.h" #include "common.h" + +#if defined(MS_WIN32) || defined(__CYGWIN__) +#define EXPORT(x) __declspec(dllexport) x +#else +#define EXPORT(x) x +#endif + #define ARRAY_SIZE(a) (sizeof(a)/sizeof(a[0])) /* test PyArray_IsPythonScalar, before including private py3 compat header */ @@ -31,6 +38,12 @@ IsPythonScalar(PyObject * dummy, PyObject *args) #include "npy_pycompat.h" +/** Function to test calling via ctypes */ +EXPORT(void*) forward_pointer(void *x) +{ + return x; +} + /* * TODO: * - Handle mode diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py index 1158a5c85..11368587f 100644 --- a/numpy/ctypeslib.py +++ b/numpy/ctypeslib.py @@ -55,7 +55,9 @@ __all__ = ['load_library', 'ndpointer', 'test', 'ctypes_load_library', 'c_intp', 'as_ctypes', 'as_array'] import os -from numpy import integer, ndarray, dtype as _dtype, deprecate, array +from numpy import ( + integer, ndarray, dtype as _dtype, deprecate, array, frombuffer +) from numpy.core.multiarray import _flagdict, flagsobj try: @@ -175,24 +177,6 @@ def _flags_fromnum(num): class _ndptr(_ndptr_base): - - def _check_retval_(self): - """This method is called when this class is used as the .restype - attribute for a shared-library function. It constructs a numpy - array from a void pointer.""" - return array(self) - - @property - def __array_interface__(self): - return {'descr': self._dtype_.descr, - '__ref': self, - 'strides': None, - 'shape': self._shape_, - 'version': 3, - 'typestr': self._dtype_.descr[0][1], - 'data': (self.value, False), - } - @classmethod def from_param(cls, obj): if not isinstance(obj, ndarray): @@ -213,6 +197,34 @@ class _ndptr(_ndptr_base): return obj.ctypes +class _concrete_ndptr(_ndptr): + """ + Like _ndptr, but with `_shape_` and `_dtype_` specified. + + Notably, this means the pointer has enough information to reconstruct + the array, which is not generally true. + """ + def _check_retval_(self): + """ + This method is called when this class is used as the .restype + attribute for a shared-library function, to automatically wrap the + pointer into an array. + """ + return self.contents + + @property + def contents(self): + """ + Get an ndarray viewing the data pointed to by this pointer. + + This mirrors the `contents` attribute of a normal ctypes pointer + """ + full_dtype = _dtype((self._dtype_, self._shape_)) + full_ctype = ctypes.c_char * full_dtype.itemsize + buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents + return frombuffer(buffer, dtype=full_dtype).squeeze(axis=0) + + # Factory for an array-checking class with from_param defined for # use with ctypes argtypes mechanism _pointer_type_cache = {} @@ -320,7 +332,12 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None): if flags is not None: name += "_"+"_".join(flags) - klass = type("ndpointer_%s"%name, (_ndptr,), + if dtype is not None and shape is not None: + base = _concrete_ndptr + else: + base = _ndptr + + klass = type("ndpointer_%s"%name, (base,), {"_dtype_": dtype, "_shape_" : shape, "_ndim_" : ndim, diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py index a6d73b152..53b75db07 100644 --- a/numpy/tests/test_ctypeslib.py +++ b/numpy/tests/test_ctypeslib.py @@ -9,20 +9,30 @@ from numpy.distutils.misc_util import get_shared_lib_extension from numpy.testing import assert_, assert_array_equal, assert_raises, assert_equal try: + import ctypes +except ImportError: + ctypes = None +else: cdll = None + test_cdll = None if hasattr(sys, 'gettotalrefcount'): try: cdll = load_library('_multiarray_umath_d', np.core._multiarray_umath.__file__) except OSError: pass + try: + test_cdll = load_library('_multiarray_tests', np.core._multiarray_tests.__file__) + except OSError: + pass if cdll is None: cdll = load_library('_multiarray_umath', np.core._multiarray_umath.__file__) - _HAS_CTYPE = True -except ImportError: - _HAS_CTYPE = False + if test_cdll is None: + test_cdll = load_library('_multiarray_tests', np.core._multiarray_tests.__file__) + + c_forward_pointer = test_cdll.forward_pointer -@pytest.mark.skipif(not _HAS_CTYPE, +@pytest.mark.skipif(ctypes is None, reason="ctypes not available in this python") @pytest.mark.skipif(sys.platform == 'cygwin', reason="Known to fail on cygwin") @@ -117,8 +127,63 @@ class TestNdpointer(object): assert_(ndpointer(shape=2) is not ndpointer(ndim=2)) assert_(ndpointer(ndim=2) is not ndpointer(shape=2)) +@pytest.mark.skipif(ctypes is None, + reason="ctypes not available on this python installation") +class TestNdpointerCFunc(object): + def test_arguments(self): + """ Test that arguments are coerced from arrays """ + c_forward_pointer.restype = ctypes.c_void_p + c_forward_pointer.argtypes = (ndpointer(ndim=2),) + + c_forward_pointer(np.zeros((2, 3))) + # too many dimensions + assert_raises( + ctypes.ArgumentError, c_forward_pointer, np.zeros((2, 3, 4))) + + @pytest.mark.parametrize( + 'dt', [ + float, + np.dtype(dict( + formats=['<i4', '<i4'], + names=['a', 'b'], + offsets=[0, 2], + itemsize=6 + )) + ], ids=[ + 'float', + 'overlapping-fields' + ] + ) + def test_return(self, dt): + """ Test that return values are coerced to arrays """ + arr = np.zeros((2, 3), dt) + ptr_type = ndpointer(shape=arr.shape, dtype=arr.dtype) + + c_forward_pointer.restype = ptr_type + c_forward_pointer.argtypes = (ptr_type,) + + # check that the arrays are equivalent views on the same data + arr2 = c_forward_pointer(arr) + assert_equal(arr2.dtype, arr.dtype) + assert_equal(arr2.shape, arr.shape) + assert_equal( + arr2.__array_interface__['data'], + arr.__array_interface__['data'] + ) + + def test_vague_return_value(self): + """ Test that vague ndpointer return values do not promote to arrays """ + arr = np.zeros((2, 3)) + ptr_type = ndpointer(dtype=arr.dtype) + + c_forward_pointer.restype = ptr_type + c_forward_pointer.argtypes = (ptr_type,) + + ret = c_forward_pointer(arr) + assert_(isinstance(ret, ptr_type)) + -@pytest.mark.skipif(not _HAS_CTYPE, +@pytest.mark.skipif(ctypes is None, reason="ctypes not available on this python installation") class TestAsArray(object): def test_array(self): |