diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/buffer.c | 115 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 23 |
2 files changed, 138 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c index 3ae3c7d0d..975891b3a 100644 --- a/numpy/core/src/multiarray/buffer.c +++ b/numpy/core/src/multiarray/buffer.c @@ -767,11 +767,18 @@ NPY_NO_EXPORT PyBufferProcs array_as_buffer = { * Convert PEP 3118 format string to PyArray_Descr */ +static int +_descriptor_from_pep3118_format_fast(char *s, PyObject **result); + +static int +_pep3118_letter_to_type(char letter, int native, int complex); + NPY_NO_EXPORT PyArray_Descr* _descriptor_from_pep3118_format(char *s) { char *buf, *p; int in_name = 0; + int obtained; PyObject *descr; PyObject *str; PyObject *_numpy_internal; @@ -780,6 +787,12 @@ _descriptor_from_pep3118_format(char *s) return PyArray_DescrNewFromType(NPY_BYTE); } + /* Fast path */ + obtained = _descriptor_from_pep3118_format_fast(s, &descr); + if (obtained) { + return (PyArray_Descr*)descr; + } + /* Strip whitespace, except from field names */ buf = malloc(strlen(s) + 1); if (buf == NULL) { @@ -828,3 +841,105 @@ _descriptor_from_pep3118_format(char *s) } return (PyArray_Descr*)descr; } + +/* + * Fast path for parsing buffer strings corresponding to simple types. + * + * Currently, this deals only with single-element data types. + */ + +static int +_descriptor_from_pep3118_format_fast(char *s, PyObject **result) +{ + PyArray_Descr *descr; + + int is_standard_size = 0; + char byte_order = '='; + int is_complex = 0; + + int type_num = NPY_BYTE; + int item_seen = 0; + + for (; *s != '\0'; ++s) { + is_complex = 0; + switch (*s) { + case '@': + case '^': + /* ^ means no alignment; doesn't matter for a single element */ + byte_order = '='; + is_standard_size = 0; + break; + case '<': + byte_order = '<'; + is_standard_size = 1; + break; + case '>': + case '!': + byte_order = '>'; + is_standard_size = 1; + break; + case '=': + byte_order = '='; + is_standard_size = 1; + break; + case 'Z': + is_complex = 1; + ++s; + default: + if (item_seen) { + /* Not a single-element data type */ + return 0; + } + type_num = _pep3118_letter_to_type(*s, !is_standard_size, + is_complex); + if (type_num < 0) { + /* Something unknown */ + return 0; + } + item_seen = 1; + break; + } + } + + if (!item_seen) { + return 0; + } + + descr = PyArray_DescrFromType(type_num); + if (byte_order == '=') { + *result = (PyObject*)descr; + } + else { + *result = (PyObject*)PyArray_DescrNewByteorder(descr, byte_order); + Py_DECREF(descr); + } + + return 1; +} + +static int +_pep3118_letter_to_type(char letter, int native, int complex) +{ + switch (letter) + { + case '?': return NPY_BOOL; + case 'b': return NPY_BYTE; + case 'B': return NPY_UBYTE; + case 'h': return native ? NPY_SHORT : NPY_INT16; + case 'H': return native ? NPY_USHORT : NPY_UINT16; + case 'i': return native ? NPY_INT : NPY_INT32; + case 'I': return native ? NPY_UINT : NPY_UINT32; + case 'l': return native ? NPY_LONG : NPY_INT32; + case 'L': return native ? NPY_ULONG : NPY_UINT32; + case 'q': return native ? NPY_LONGLONG : NPY_INT64; + case 'Q': return native ? NPY_ULONGLONG : NPY_UINT64; + case 'e': return NPY_HALF; + case 'f': return complex ? NPY_CFLOAT : NPY_FLOAT; + case 'd': return complex ? NPY_CDOUBLE : NPY_DOUBLE; + case 'g': return native ? (complex ? NPY_CLONGDOUBLE : NPY_LONGDOUBLE) : -1; + default: + /* Other unhandled cases */ + return -1; + } + return -1; +} diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 77da5543d..19e6b5d5d 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3239,6 +3239,29 @@ class TestNewBufferProtocol(object): x = np.array(half_list, dtype='<e') self._check_roundtrip(x) + def test_roundtrip_single_types(self): + for typ in np.typeDict.values(): + dtype = np.dtype(typ) + + if dtype.char in 'Mm': + # datetimes cannot be used in buffers + continue + if dtype.char == 'V': + # skip void + continue + + x = np.zeros(4, dtype=dtype) + self._check_roundtrip(x) + + if dtype.char not in 'qQgG': + dt = dtype.newbyteorder('<') + x = np.zeros(4, dtype=dt) + self._check_roundtrip(x) + + dt = dtype.newbyteorder('>') + x = np.zeros(4, dtype=dt) + self._check_roundtrip(x) + def test_export_simple_1d(self): x = np.array([1, 2, 3, 4, 5], dtype='i') y = memoryview(x) |
