summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorEli Stevens <elis@doselab.com>2011-03-24 21:50:27 -0700
committerMark Wiebe <mwwiebe@gmail.com>2011-04-02 11:22:51 -0700
commitf9c7bde68655de06546f5c8dba30314888e409cc (patch)
tree50a68d8b895260ac905c5ef08020121849da6abd /numpy/core
parent65b77ee94131bf8365d8a6dba6fa19da1269339c (diff)
downloadnumpy-f9c7bde68655de06546f5c8dba30314888e409cc.tar.gz
ENH: Changes (and tests) to allow exporting half-floats through the buffer interface. (#1789)
Code: Added NPY_HALF to switch (descr->type_num) in _buffer_format_string. Added 'e' keys to the _pep3118_native_map and _pep3118_standard_map. Tests: Added entries to the generic round-trip tests. Added specialized half-float test that round-trips example values from the wikipedia page. http://en.wikipedia.org/wiki/Half_precision_floating-point_format http://mail.scipy.org/pipermail/numpy-discussion/2011-March/055795.html
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/_internal.py2
-rw-r--r--numpy/core/src/multiarray/buffer.c1
-rw-r--r--numpy/core/tests/test_multiarray.py35
3 files changed, 32 insertions, 6 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 67ca1716f..5298f412b 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -363,6 +363,7 @@ _pep3118_native_map = {
'L': 'L',
'q': 'q',
'Q': 'Q',
+ 'e': 'e',
'f': 'f',
'd': 'd',
'g': 'g',
@@ -388,6 +389,7 @@ _pep3118_standard_map = {
'L': 'u4',
'q': 'i8',
'Q': 'u8',
+ 'e': 'f2',
'f': 'f',
'd': 'd',
'Zf': 'F',
diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c
index 9f1ddc1f7..9bc45a76f 100644
--- a/numpy/core/src/multiarray/buffer.c
+++ b/numpy/core/src/multiarray/buffer.c
@@ -358,6 +358,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
break;
case NPY_LONGLONG: if (_append_char(str, 'q')) return -1; break;
case NPY_ULONGLONG: if (_append_char(str, 'Q')) return -1; break;
+ case NPY_HALF: if (_append_char(str, 'e')) return -1; break;
case NPY_FLOAT: if (_append_char(str, 'f')) return -1; break;
case NPY_DOUBLE: if (_append_char(str, 'd')) return -1; break;
case NPY_LONGDOUBLE: if (_append_char(str, 'g')) return -1; break;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7cdd45eb0..e2010b6c2 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2008,9 +2008,11 @@ if sys.version_info >= (2, 6):
('l', 'S4'),
('m', 'U4'),
('n', 'V3'),
- ('o', '?')]
+ ('o', '?'),
+ ('p', np.half),
+ ]
x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- asbytes('aaaa'), 'bbbb', asbytes('xxx'), True)],
+ asbytes('aaaa'), 'bbbb', asbytes('xxx'), True, 1.0)],
dtype=dt)
self._check_roundtrip(x)
@@ -2042,6 +2044,25 @@ if sys.version_info >= (2, 6):
x = np.array([1,2,3], dtype='<q')
assert_raises(ValueError, self._check_roundtrip, x)
+ def test_roundtrip_half(self):
+ half_list = [
+ 1.0,
+ -2.0,
+ 6.5504 * 10**4, # (max half precision)
+ 2**-14, # ~= 6.10352 * 10**-5 (minimum positive normal)
+ 2**-24, # ~= 5.96046 * 10**-8 (minimum strictly positive subnormal)
+ 0.0,
+ -0.0,
+ float('+inf'),
+ float('-inf'),
+ 0.333251953125, # ~= 1/3
+ ]
+
+ x = np.array(half_list, dtype='>e')
+ self._check_roundtrip(x)
+ x = np.array(half_list, dtype='<e')
+ self._check_roundtrip(x)
+
def test_export_simple_1d(self):
x = np.array([1,2,3,4,5], dtype='i')
y = memoryview(x)
@@ -2092,9 +2113,11 @@ if sys.version_info >= (2, 6):
('l', 'S4'),
('m', 'U4'),
('n', 'V3'),
- ('o', '?')]
+ ('o', '?'),
+ ('p', np.half),
+ ]
x = np.array([(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- asbytes('aaaa'), 'bbbb', asbytes(' '), True)],
+ asbytes('aaaa'), 'bbbb', asbytes(' '), True, 1.0)],
dtype=dt)
y = memoryview(x)
assert_equal(y.shape, (1,))
@@ -2103,9 +2126,9 @@ if sys.version_info >= (2, 6):
sz = sum([dtype(b).itemsize for a, b in dt])
if dtype('l').itemsize == 4:
- assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}')
+ assert_equal(y.format, 'T{b:a:=h:b:i:c:l:d:^q:dx:B:e:@H:f:=I:g:L:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}')
else:
- assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:}')
+ assert_equal(y.format, 'T{b:a:=h:b:i:c:q:d:^q:dx:B:e:@H:f:=I:g:Q:h:^Q:hx:=f:i:d:j:^g:k:=Zf:ix:Zd:jx:^Zg:kx:4s:l:=4w:m:3x:n:?:o:@e:p:}')
assert_equal(y.strides, (sz,))
assert_equal(y.itemsize, sz)