summaryrefslogtreecommitdiff
path: root/numpy/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/tests')
-rw-r--r--numpy/tests/test_ctypeslib.py77
1 files changed, 67 insertions, 10 deletions
diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py
index 0f0d6dbc4..75ce9c8ca 100644
--- a/numpy/tests/test_ctypeslib.py
+++ b/numpy/tests/test_ctypeslib.py
@@ -4,9 +4,9 @@ import sys
import pytest
import numpy as np
-from numpy.ctypeslib import ndpointer, load_library
+from numpy.ctypeslib import ndpointer, load_library, as_array
from numpy.distutils.misc_util import get_shared_lib_extension
-from numpy.testing import assert_, assert_raises
+from numpy.testing import assert_, assert_array_equal, assert_raises, assert_equal
try:
cdll = None
@@ -21,11 +21,12 @@ try:
except ImportError:
_HAS_CTYPE = False
+
+@pytest.mark.skipif(not _HAS_CTYPE,
+ reason="ctypes not available in this python")
+@pytest.mark.skipif(sys.platform == 'cygwin',
+ reason="Known to fail on cygwin")
class TestLoadLibrary(object):
- @pytest.mark.skipif(not _HAS_CTYPE,
- reason="ctypes not available in this python")
- @pytest.mark.skipif(sys.platform == 'cygwin',
- reason="Known to fail on cygwin")
def test_basic(self):
try:
# Should succeed
@@ -35,10 +36,6 @@ class TestLoadLibrary(object):
" (import error was: %s)" % str(e))
print(msg)
- @pytest.mark.skipif(not _HAS_CTYPE,
- reason="ctypes not available in this python")
- @pytest.mark.skipif(sys.platform == 'cygwin',
- reason="Known to fail on cygwin")
def test_basic2(self):
# Regression for #801: load_library with a full library name
# (including extension) does not work.
@@ -54,6 +51,7 @@ class TestLoadLibrary(object):
" (import error was: %s)" % str(e))
print(msg)
+
class TestNdpointer(object):
def test_dtype(self):
dt = np.intc
@@ -113,3 +111,62 @@ class TestNdpointer(object):
a1 = ndpointer(dtype=np.float64)
a2 = ndpointer(dtype=np.float64)
assert_(a1 == a2)
+
+
+@pytest.mark.skipif(not _HAS_CTYPE,
+ reason="ctypes not available on this python installation")
+class TestAsArray(object):
+ def test_array(self):
+ from ctypes import c_int
+
+ pair_t = c_int * 2
+ a = as_array(pair_t(1, 2))
+ assert_equal(a.shape, (2,))
+ assert_array_equal(a, np.array([1, 2]))
+ a = as_array((pair_t * 3)(pair_t(1, 2), pair_t(3, 4), pair_t(5, 6)))
+ assert_equal(a.shape, (3, 2))
+ assert_array_equal(a, np.array([[1, 2], [3, 4], [5, 6]]))
+
+ def test_pointer(self):
+ from ctypes import c_int, cast, POINTER
+
+ p = cast((c_int * 10)(*range(10)), POINTER(c_int))
+
+ a = as_array(p, shape=(10,))
+ assert_equal(a.shape, (10,))
+ assert_array_equal(a, np.arange(10))
+
+ a = as_array(p, shape=(2, 5))
+ assert_equal(a.shape, (2, 5))
+ assert_array_equal(a, np.arange(10).reshape((2, 5)))
+
+ # shape argument is required
+ assert_raises(TypeError, as_array, p)
+
+ def test_struct_array_pointer(self):
+ from ctypes import c_int16, Structure, pointer
+
+ class Struct(Structure):
+ _fields_ = [('a', c_int16)]
+
+ Struct3 = 3 * Struct
+
+ c_array = (2 * Struct3)(
+ Struct3(Struct(a=1), Struct(a=2), Struct(a=3)),
+ Struct3(Struct(a=4), Struct(a=5), Struct(a=6))
+ )
+
+ expected = np.array([
+ [(1,), (2,), (3,)],
+ [(4,), (5,), (6,)],
+ ], dtype=[('a', np.int16)])
+
+ def check(x):
+ assert_equal(x.dtype, expected.dtype)
+ assert_equal(x, expected)
+
+ # all of these should be equivalent
+ check(as_array(c_array))
+ check(as_array(pointer(c_array), shape=()))
+ check(as_array(pointer(c_array[0]), shape=(2,)))
+ check(as_array(pointer(c_array[0][0]), shape=(2, 3)))