summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-11-21 21:53:10 -0600
committerGitHub <noreply@github.com>2018-11-21 21:53:10 -0600
commitf01924f92208cc8b40c5830bb240b3c26cbaea08 (patch)
tree850794921aef547c14a48be870aca3574d4f5904 /numpy
parent4f1541e1cb68beb3049a21cbdec6e3d30c2afbbb (diff)
parent73322a0ba923ff6e6429463ba1e31acc69ce1ea7 (diff)
downloadnumpy-f01924f92208cc8b40c5830bb240b3c26cbaea08.tar.gz
Merge pull request #12431 from eric-wieser/ndpointer-return
BUG/ENH: Fix use of ndpointer in return values
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/_multiarray_tests.c.src13
-rw-r--r--numpy/ctypeslib.py57
-rw-r--r--numpy/tests/test_ctypeslib.py75
3 files changed, 120 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src
index b98c3afb3..f05ee1431 100644
--- a/numpy/core/src/multiarray/_multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/_multiarray_tests.c.src
@@ -11,6 +11,13 @@
#include "npy_extint128.h"
#include "common.h"
+
+#if defined(MS_WIN32) || defined(__CYGWIN__)
+#define EXPORT(x) __declspec(dllexport) x
+#else
+#define EXPORT(x) x
+#endif
+
#define ARRAY_SIZE(a) (sizeof(a)/sizeof(a[0]))
/* test PyArray_IsPythonScalar, before including private py3 compat header */
@@ -31,6 +38,12 @@ IsPythonScalar(PyObject * dummy, PyObject *args)
#include "npy_pycompat.h"
+/** Function to test calling via ctypes */
+EXPORT(void*) forward_pointer(void *x)
+{
+ return x;
+}
+
/*
* TODO:
* - Handle mode
diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py
index 1158a5c85..11368587f 100644
--- a/numpy/ctypeslib.py
+++ b/numpy/ctypeslib.py
@@ -55,7 +55,9 @@ __all__ = ['load_library', 'ndpointer', 'test', 'ctypes_load_library',
'c_intp', 'as_ctypes', 'as_array']
import os
-from numpy import integer, ndarray, dtype as _dtype, deprecate, array
+from numpy import (
+ integer, ndarray, dtype as _dtype, deprecate, array, frombuffer
+)
from numpy.core.multiarray import _flagdict, flagsobj
try:
@@ -175,24 +177,6 @@ def _flags_fromnum(num):
class _ndptr(_ndptr_base):
-
- def _check_retval_(self):
- """This method is called when this class is used as the .restype
- attribute for a shared-library function. It constructs a numpy
- array from a void pointer."""
- return array(self)
-
- @property
- def __array_interface__(self):
- return {'descr': self._dtype_.descr,
- '__ref': self,
- 'strides': None,
- 'shape': self._shape_,
- 'version': 3,
- 'typestr': self._dtype_.descr[0][1],
- 'data': (self.value, False),
- }
-
@classmethod
def from_param(cls, obj):
if not isinstance(obj, ndarray):
@@ -213,6 +197,34 @@ class _ndptr(_ndptr_base):
return obj.ctypes
+class _concrete_ndptr(_ndptr):
+ """
+ Like _ndptr, but with `_shape_` and `_dtype_` specified.
+
+ Notably, this means the pointer has enough information to reconstruct
+ the array, which is not generally true.
+ """
+ def _check_retval_(self):
+ """
+ This method is called when this class is used as the .restype
+ attribute for a shared-library function, to automatically wrap the
+ pointer into an array.
+ """
+ return self.contents
+
+ @property
+ def contents(self):
+ """
+ Get an ndarray viewing the data pointed to by this pointer.
+
+ This mirrors the `contents` attribute of a normal ctypes pointer
+ """
+ full_dtype = _dtype((self._dtype_, self._shape_))
+ full_ctype = ctypes.c_char * full_dtype.itemsize
+ buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents
+ return frombuffer(buffer, dtype=full_dtype).squeeze(axis=0)
+
+
# Factory for an array-checking class with from_param defined for
# use with ctypes argtypes mechanism
_pointer_type_cache = {}
@@ -320,7 +332,12 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
if flags is not None:
name += "_"+"_".join(flags)
- klass = type("ndpointer_%s"%name, (_ndptr,),
+ if dtype is not None and shape is not None:
+ base = _concrete_ndptr
+ else:
+ base = _ndptr
+
+ klass = type("ndpointer_%s"%name, (base,),
{"_dtype_": dtype,
"_shape_" : shape,
"_ndim_" : ndim,
diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py
index a6d73b152..53b75db07 100644
--- a/numpy/tests/test_ctypeslib.py
+++ b/numpy/tests/test_ctypeslib.py
@@ -9,20 +9,30 @@ from numpy.distutils.misc_util import get_shared_lib_extension
from numpy.testing import assert_, assert_array_equal, assert_raises, assert_equal
try:
+ import ctypes
+except ImportError:
+ ctypes = None
+else:
cdll = None
+ test_cdll = None
if hasattr(sys, 'gettotalrefcount'):
try:
cdll = load_library('_multiarray_umath_d', np.core._multiarray_umath.__file__)
except OSError:
pass
+ try:
+ test_cdll = load_library('_multiarray_tests', np.core._multiarray_tests.__file__)
+ except OSError:
+ pass
if cdll is None:
cdll = load_library('_multiarray_umath', np.core._multiarray_umath.__file__)
- _HAS_CTYPE = True
-except ImportError:
- _HAS_CTYPE = False
+ if test_cdll is None:
+ test_cdll = load_library('_multiarray_tests', np.core._multiarray_tests.__file__)
+
+ c_forward_pointer = test_cdll.forward_pointer
-@pytest.mark.skipif(not _HAS_CTYPE,
+@pytest.mark.skipif(ctypes is None,
reason="ctypes not available in this python")
@pytest.mark.skipif(sys.platform == 'cygwin',
reason="Known to fail on cygwin")
@@ -117,8 +127,63 @@ class TestNdpointer(object):
assert_(ndpointer(shape=2) is not ndpointer(ndim=2))
assert_(ndpointer(ndim=2) is not ndpointer(shape=2))
+@pytest.mark.skipif(ctypes is None,
+ reason="ctypes not available on this python installation")
+class TestNdpointerCFunc(object):
+ def test_arguments(self):
+ """ Test that arguments are coerced from arrays """
+ c_forward_pointer.restype = ctypes.c_void_p
+ c_forward_pointer.argtypes = (ndpointer(ndim=2),)
+
+ c_forward_pointer(np.zeros((2, 3)))
+ # too many dimensions
+ assert_raises(
+ ctypes.ArgumentError, c_forward_pointer, np.zeros((2, 3, 4)))
+
+ @pytest.mark.parametrize(
+ 'dt', [
+ float,
+ np.dtype(dict(
+ formats=['<i4', '<i4'],
+ names=['a', 'b'],
+ offsets=[0, 2],
+ itemsize=6
+ ))
+ ], ids=[
+ 'float',
+ 'overlapping-fields'
+ ]
+ )
+ def test_return(self, dt):
+ """ Test that return values are coerced to arrays """
+ arr = np.zeros((2, 3), dt)
+ ptr_type = ndpointer(shape=arr.shape, dtype=arr.dtype)
+
+ c_forward_pointer.restype = ptr_type
+ c_forward_pointer.argtypes = (ptr_type,)
+
+ # check that the arrays are equivalent views on the same data
+ arr2 = c_forward_pointer(arr)
+ assert_equal(arr2.dtype, arr.dtype)
+ assert_equal(arr2.shape, arr.shape)
+ assert_equal(
+ arr2.__array_interface__['data'],
+ arr.__array_interface__['data']
+ )
+
+ def test_vague_return_value(self):
+ """ Test that vague ndpointer return values do not promote to arrays """
+ arr = np.zeros((2, 3))
+ ptr_type = ndpointer(dtype=arr.dtype)
+
+ c_forward_pointer.restype = ptr_type
+ c_forward_pointer.argtypes = (ptr_type,)
+
+ ret = c_forward_pointer(arr)
+ assert_(isinstance(ret, ptr_type))
+
-@pytest.mark.skipif(not _HAS_CTYPE,
+@pytest.mark.skipif(ctypes is None,
reason="ctypes not available on this python installation")
class TestAsArray(object):
def test_array(self):