diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-12-08 03:23:10 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-12-08 03:23:10 +0000 |
commit | db487c204c9df0472642d14d80632382463ecbcc (patch) | |
tree | e066024d8d2388ff525556dc0f11478174294f75 /numpy/core | |
parent | e3a1550ff1417599f68207f60ea622404d281178 (diff) | |
download | numpy-db487c204c9df0472642d14d80632382463ecbcc.tar.gz |
Add feature so that user-defined types can be recognized when instances of scalars of those types are present in the nested sequence.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/arrayprint.py | 20 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 62 |
2 files changed, 54 insertions, 28 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 42ec4cadb..f82f5ee21 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -144,23 +144,23 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', try: format_function = a._format except AttributeError: - dtype = a.dtype.type - if issubclass(dtype, _nt.bool): + dtypeobj = a.dtype.type + if issubclass(dtypeobj, _nt.bool): format = "%s" format_function = lambda x: format % x - if issubclass(dtype, _nt.integer): + if issubclass(dtypeobj, _nt.integer): max_str_len = max(len(str(max_reduce(data))), len(str(min_reduce(data)))) format = '%' + str(max_str_len) + 'd' format_function = lambda x: _formatInteger(x, format) - elif issubclass(dtype, _nt.floating): - if issubclass(dtype, _nt.longfloat): + elif issubclass(dtypeobj, _nt.floating): + if issubclass(dtypeobj, _nt.longfloat): format_function = _longfloatFormatter(precision) else: format = _floatFormat(data, precision, suppress_small) format_function = lambda x: _formatFloat(x, format) - elif issubclass(dtype, _nt.complexfloating): - if issubclass(dtype, _nt.clongfloat): + elif issubclass(dtypeobj, _nt.complexfloating): + if issubclass(dtypeobj, _nt.clongfloat): format_function = _clongfloatFormatter(precision) else: real_format = _floatFormat( @@ -169,13 +169,13 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', data.imag, precision, suppress_small, sign=1) format_function = lambda x: \ _formatComplex(x, real_format, imag_format) - elif issubclass(dtype, _nt.unicode_) or \ - issubclass(dtype, _nt.string_): + elif issubclass(dtypeobj, _nt.unicode_) or \ + issubclass(dtypeobj, _nt.string_): format = "%s" format_function = lambda x: repr(x) else: format = '%s' - format_function = lambda x: format % str(x) + format_function = lambda x: str(x) next_line_prefix = " " # skip over "[" next_line_prefix += " "*len(prefix) # skip over array( diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index aef04b4c7..eaed75f14 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -6918,27 +6918,53 @@ _array_find_python_scalar_type(PyObject *op) { if (PyFloat_Check(op)) { return PyArray_DescrFromType(PyArray_DOUBLE); - } else if (PyComplex_Check(op)) { - return PyArray_DescrFromType(PyArray_CDOUBLE); - } else if (PyInt_Check(op)) { - /* bools are a subclass of int */ - if (PyBool_Check(op)) { - return PyArray_DescrFromType(PyArray_BOOL); - } else { - return PyArray_DescrFromType(PyArray_LONG); - } - } else if (PyLong_Check(op)) { - /* if integer can fit into a longlong then return that - */ - if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) { - PyErr_Clear(); - return PyArray_DescrFromType(PyArray_OBJECT); - } - return PyArray_DescrFromType(PyArray_LONGLONG); + } + else if (PyComplex_Check(op)) { + return PyArray_DescrFromType(PyArray_CDOUBLE); + } + else if (PyInt_Check(op)) { + /* bools are a subclass of int */ + if (PyBool_Check(op)) { + return PyArray_DescrFromType(PyArray_BOOL); + } else { + return PyArray_DescrFromType(PyArray_LONG); + } + } + else if (PyLong_Check(op)) { + /* if integer can fit into a longlong then return that + */ + if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) { + PyErr_Clear(); + return PyArray_DescrFromType(PyArray_OBJECT); + } + return PyArray_DescrFromType(PyArray_LONGLONG); } return NULL; } +static PyArray_Descr * +_use_default_type(PyObject *op) +{ + int typenum, l; + PyObject *type; + + typenum = -1; + l = 0; + type = (PyObject *)op->ob_type; + while (l < PyArray_NUMUSERTYPES) { + if (type == (PyObject *)(userdescrs[l]->typeobj)) { + typenum = l + PyArray_USERDEF; + break; + } + l++; + } + if (typenum == -1) { + typenum = PyArray_OBJECT; + } + return PyArray_DescrFromType(typenum); +} + + /* op is an object to be converted to an ndarray. minitype is the minimum type-descriptor needed. @@ -7086,7 +7112,7 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) deflt: - chktype = PyArray_DescrFromType(PyArray_OBJECT); + chktype = _use_default_type(op); finish: |