From f9c7bde68655de06546f5c8dba30314888e409cc Mon Sep 17 00:00:00 2001 From: Eli Stevens Date: Thu, 24 Mar 2011 21:50:27 -0700 Subject: ENH: Changes (and tests) to allow exporting half-floats through the buffer interface. (#1789) Code: Added NPY_HALF to switch (descr->type_num) in _buffer_format_string. Added 'e' keys to the _pep3118_native_map and _pep3118_standard_map. Tests: Added entries to the generic round-trip tests. Added specialized half-float test that round-trips example values from the wikipedia page. http://en.wikipedia.org/wiki/Half_precision_floating-point_format http://mail.scipy.org/pipermail/numpy-discussion/2011-March/055795.html --- numpy/core/_internal.py | 2 ++ numpy/core/src/multiarray/buffer.c | 1 + numpy/core/tests/test_multiarray.py | 35 +++++++++++++++++++++++++++++------ 3 files changed, 32 insertions(+), 6 deletions(-) (limited to 'numpy') diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 67ca1716f..5298f412b 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -363,6 +363,7 @@ _pep3118_native_map = { 'L': 'L', 'q': 'q', 'Q': 'Q', + 'e': 'e', 'f': 'f', 'd': 'd', 'g': 'g', @@ -388,6 +389,7 @@ _pep3118_standard_map = { 'L': 'u4', 'q': 'i8', 'Q': 'u8', + 'e': 'f2', 'f': 'f', 'd': 'd', 'Zf': 'F', diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c index 9f1ddc1f7..9bc45a76f 100644 --- a/numpy/core/src/multiarray/buffer.c +++ b/numpy/core/src/multiarray/buffer.c @@ -358,6 +358,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str, break; case NPY_LONGLONG: if (_append_char(str, 'q')) return -1; break; case NPY_ULONGLONG: if (_append_char(str, 'Q')) return -1; break; + case NPY_HALF: if (_append_char(str, 'e')) return -1; break; case NPY_FLOAT: if (_append_char(str, 'f')) return -1; break; case NPY_DOUBLE: if (_append_char(str, 'd')) return -1; break; case NPY_LONGDOUBLE: if (_append_char(str, 'g')) return -1; break; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7cdd45eb0..e2010b6c2 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2008,9 +2008,11 @@ if sys.version_info >= (2, 6): ('l', 'S4'), ('m', 'U4'), ('n', 'V3'), - ('o', '?')] + ('o', '?'), + ('p', np.half), + ] x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - asbytes('aaaa'), 'bbbb', asbytes('xxx'), True)], + asbytes('aaaa'), 'bbbb', asbytes('xxx'), True, 1.0)], dtype=dt) self._check_roundtrip(x) @@ -2042,6 +2044,25 @@ if sys.version_info >= (2, 6): x = np.array([1,2,3], dtype='= (2, 6): ('l', 'S4'), ('m', 'U4'), ('n', 'V3'), - ('o', '?')] + ('o', '?'), + ('p', np.half), + ] x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - asbytes('aaaa'), 'bbbb', asbytes(' '), True)], + asbytes('aaaa'), 'bbbb', asbytes(' '), True, 1.0)], dtype=dt) y = memoryview(x) assert_equal(y.shape, (1,)) @@ -2103,9 +2126,9 @@ if sys.version_info >= (2, 6): sz = sum([dtype(b).itemsize for a, b in dt]) if dtype('l').itemsize == 4: - assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}') + assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}') else: - assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}') + assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}') assert_equal(y.strides, (sz,)) assert_equal(y.itemsize, sz) -- cgit v1.2.1