summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMircea-Akos Brumă <bruma.mircea.a@gmail.com>2018-11-19 02:10:35 +0100
committerEric Wieser <wieser.eric@gmail.com>2018-11-18 17:10:35 -0800
commit1466e788a43b8d4356fe35951bf0c3b0aedb554f (patch)
tree4b4f0b3f2ec0e1fc47188cfb0214e5303d7ccc17 /numpy
parent70674407108faef9ade4ccb283748200247b2b57 (diff)
downloadnumpy-1466e788a43b8d4356fe35951bf0c3b0aedb554f.tar.gz
ENH: Add support for `np.dtype(ctypes.Union)` (#12405)
Fixes #12273
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_dtype_ctypes.py21
-rw-r--r--numpy/core/tests/test_dtype.py49
2 files changed, 63 insertions, 7 deletions
diff --git a/numpy/core/_dtype_ctypes.py b/numpy/core/_dtype_ctypes.py
index ca365d2cb..4d5191aab 100644
--- a/numpy/core/_dtype_ctypes.py
+++ b/numpy/core/_dtype_ctypes.py
@@ -78,6 +78,22 @@ def dtype_from_ctypes_scalar(t):
return np.dtype(t._type_)
+def dtype_from_ctypes_union(t):
+ formats = []
+ offsets = []
+ names = []
+ for fname, ftyp in t._fields_:
+ names.append(fname)
+ formats.append(dtype_from_ctypes_type(ftyp))
+ offsets.append(0) # Union fields are offset to 0
+
+ return np.dtype(dict(
+ formats=formats,
+ offsets=offsets,
+ names=names,
+ itemsize=ctypes.sizeof(t)))
+
+
def dtype_from_ctypes_type(t):
"""
Construct a dtype object from a ctypes type
@@ -89,10 +105,7 @@ def dtype_from_ctypes_type(t):
elif issubclass(t, _ctypes.Structure):
return _from_ctypes_structure(t)
elif issubclass(t, _ctypes.Union):
- # TODO
- raise NotImplementedError(
- "conversion from ctypes.Union types like {} to dtype"
- .format(t.__name__))
+ return dtype_from_ctypes_union(t)
elif isinstance(t._type_, str):
return dtype_from_ctypes_scalar(t)
else:
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index f2e7f8f50..a39573495 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -807,9 +807,6 @@ class TestFromCTypes(object):
p_uint8 = ctypes.POINTER(ctypes.c_uint8)
assert_raises(TypeError, np.dtype, p_uint8)
- @pytest.mark.xfail(
- reason="Unions are not implemented",
- raises=NotImplementedError)
def test_union(self):
class Union(ctypes.Union):
_fields_ = [
@@ -824,6 +821,52 @@ class TestFromCTypes(object):
))
self.check(Union, expected)
+ def test_union_with_struct_packed(self):
+ class Struct(ctypes.Structure):
+ _pack_ = 1
+ _fields_ = [
+ ('one', ctypes.c_uint8),
+ ('two', ctypes.c_uint32)
+ ]
+
+ class Union(ctypes.Union):
+ _fields_ = [
+ ('a', ctypes.c_uint8),
+ ('b', ctypes.c_uint16),
+ ('c', ctypes.c_uint32),
+ ('d', Struct),
+ ]
+ expected = np.dtype(dict(
+ names=['a', 'b', 'c', 'd'],
+ formats=['u1', np.uint16, np.uint32, [('one', 'u1'), ('two', np.uint32)]],
+ offsets=[0, 0, 0, 0],
+ itemsize=ctypes.sizeof(Union)
+ ))
+ self.check(Union, expected)
+
+ def test_union_packed(self):
+ class Struct(ctypes.Structure):
+ _fields_ = [
+ ('one', ctypes.c_uint8),
+ ('two', ctypes.c_uint32)
+ ]
+ _pack_ = 1
+ class Union(ctypes.Union):
+ _pack_ = 1
+ _fields_ = [
+ ('a', ctypes.c_uint8),
+ ('b', ctypes.c_uint16),
+ ('c', ctypes.c_uint32),
+ ('d', Struct),
+ ]
+ expected = np.dtype(dict(
+ names=['a', 'b', 'c', 'd'],
+ formats=['u1', np.uint16, np.uint32, [('one', 'u1'), ('two', np.uint32)]],
+ offsets=[0, 0, 0, 0],
+ itemsize=ctypes.sizeof(Union)
+ ))
+ self.check(Union, expected)
+
def test_packed_structure(self):
class PackedStructure(ctypes.Structure):
_pack_ = 1