diff options
author | Eli Stevens <elis@doselab.com> | 2011-03-24 21:50:27 -0700 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-04-02 11:22:51 -0700 |
commit | f9c7bde68655de06546f5c8dba30314888e409cc (patch) | |
tree | 50a68d8b895260ac905c5ef08020121849da6abd /numpy/core | |
parent | 65b77ee94131bf8365d8a6dba6fa19da1269339c (diff) | |
download | numpy-f9c7bde68655de06546f5c8dba30314888e409cc.tar.gz |
ENH: Changes (and tests) to allow exporting half-floats through the buffer interface. (#1789)
Code:
Added NPY_HALF to switch (descr->type_num) in _buffer_format_string.
Added 'e' keys to the _pep3118_native_map and _pep3118_standard_map.
Tests:
Added entries to the generic round-trip tests.
Added specialized half-float test that round-trips example values from the wikipedia page.
http://en.wikipedia.org/wiki/Half_precision_floating-point_format
http://mail.scipy.org/pipermail/numpy-discussion/2011-March/055795.html
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/_internal.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/buffer.c | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 35 |
3 files changed, 32 insertions, 6 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 67ca1716f..5298f412b 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -363,6 +363,7 @@ _pep3118_native_map = { 'L': 'L', 'q': 'q', 'Q': 'Q', + 'e': 'e', 'f': 'f', 'd': 'd', 'g': 'g', @@ -388,6 +389,7 @@ _pep3118_standard_map = { 'L': 'u4', 'q': 'i8', 'Q': 'u8', + 'e': 'f2', 'f': 'f', 'd': 'd', 'Zf': 'F', diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c index 9f1ddc1f7..9bc45a76f 100644 --- a/numpy/core/src/multiarray/buffer.c +++ b/numpy/core/src/multiarray/buffer.c @@ -358,6 +358,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str, break; case NPY_LONGLONG: if (_append_char(str, 'q')) return -1; break; case NPY_ULONGLONG: if (_append_char(str, 'Q')) return -1; break; + case NPY_HALF: if (_append_char(str, 'e')) return -1; break; case NPY_FLOAT: if (_append_char(str, 'f')) return -1; break; case NPY_DOUBLE: if (_append_char(str, 'd')) return -1; break; case NPY_LONGDOUBLE: if (_append_char(str, 'g')) return -1; break; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7cdd45eb0..e2010b6c2 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2008,9 +2008,11 @@ if sys.version_info >= (2, 6): ('l', 'S4'), ('m', 'U4'), ('n', 'V3'), - ('o', '?')] + ('o', '?'), + ('p', np.half), + ] x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - asbytes('aaaa'), 'bbbb', asbytes('xxx'), True)], + asbytes('aaaa'), 'bbbb', asbytes('xxx'), True, 1.0)], dtype=dt) self._check_roundtrip(x) @@ -2042,6 +2044,25 @@ if sys.version_info >= (2, 6): x = np.array([1,2,3], dtype='<q') assert_raises(ValueError, self._check_roundtrip, x) + def test_roundtrip_half(self): + half_list = [ + 1.0, + -2.0, + 6.5504 * 10**4, # (max half precision) + 2**-14, # ~= 6.10352 * 10**-5 (minimum positive normal) + 2**-24, # ~= 5.96046 * 10**-8 (minimum strictly positive subnormal) + 0.0, + -0.0, + float('+inf'), + float('-inf'), + 0.333251953125, # ~= 1/3 + ] + + x = np.array(half_list, dtype='>e') + self._check_roundtrip(x) + x = np.array(half_list, dtype='<e') + self._check_roundtrip(x) + def test_export_simple_1d(self): x = np.array([1,2,3,4,5], dtype='i') y = memoryview(x) @@ -2092,9 +2113,11 @@ if sys.version_info >= (2, 6): ('l', 'S4'), ('m', 'U4'), ('n', 'V3'), - ('o', '?')] + ('o', '?'), + ('p', np.half), + ] x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - asbytes('aaaa'), 'bbbb', asbytes(' '), True)], + asbytes('aaaa'), 'bbbb', asbytes(' '), True, 1.0)], dtype=dt) y = memoryview(x) assert_equal(y.shape, (1,)) @@ -2103,9 +2126,9 @@ if sys.version_info >= (2, 6): sz = sum([dtype(b).itemsize for a, b in dt]) if dtype('l').itemsize == 4: - assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}') + assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}') else: - assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}') + assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}') assert_equal(y.strides, (sz,)) assert_equal(y.itemsize, sz) |