summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-12-08 03:23:10 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-12-08 03:23:10 +0000
commitdb487c204c9df0472642d14d80632382463ecbcc (patch)
treee066024d8d2388ff525556dc0f11478174294f75 /numpy/core
parente3a1550ff1417599f68207f60ea622404d281178 (diff)
downloadnumpy-db487c204c9df0472642d14d80632382463ecbcc.tar.gz
Add feature so that user-defined types can be recognized when instances of scalars of those types are present in the nested sequence.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/arrayprint.py20
-rw-r--r--numpy/core/src/arrayobject.c62
2 files changed, 54 insertions, 28 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 42ec4cadb..f82f5ee21 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -144,23 +144,23 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
try:
format_function = a._format
except AttributeError:
- dtype = a.dtype.type
- if issubclass(dtype, _nt.bool):
+ dtypeobj = a.dtype.type
+ if issubclass(dtypeobj, _nt.bool):
format = "%s"
format_function = lambda x: format % x
- if issubclass(dtype, _nt.integer):
+ if issubclass(dtypeobj, _nt.integer):
max_str_len = max(len(str(max_reduce(data))),
len(str(min_reduce(data))))
format = '%' + str(max_str_len) + 'd'
format_function = lambda x: _formatInteger(x, format)
- elif issubclass(dtype, _nt.floating):
- if issubclass(dtype, _nt.longfloat):
+ elif issubclass(dtypeobj, _nt.floating):
+ if issubclass(dtypeobj, _nt.longfloat):
format_function = _longfloatFormatter(precision)
else:
format = _floatFormat(data, precision, suppress_small)
format_function = lambda x: _formatFloat(x, format)
- elif issubclass(dtype, _nt.complexfloating):
- if issubclass(dtype, _nt.clongfloat):
+ elif issubclass(dtypeobj, _nt.complexfloating):
+ if issubclass(dtypeobj, _nt.clongfloat):
format_function = _clongfloatFormatter(precision)
else:
real_format = _floatFormat(
@@ -169,13 +169,13 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
data.imag, precision, suppress_small, sign=1)
format_function = lambda x: \
_formatComplex(x, real_format, imag_format)
- elif issubclass(dtype, _nt.unicode_) or \
- issubclass(dtype, _nt.string_):
+ elif issubclass(dtypeobj, _nt.unicode_) or \
+ issubclass(dtypeobj, _nt.string_):
format = "%s"
format_function = lambda x: repr(x)
else:
format = '%s'
- format_function = lambda x: format % str(x)
+ format_function = lambda x: str(x)
next_line_prefix = " " # skip over "["
next_line_prefix += " "*len(prefix) # skip over array(
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index aef04b4c7..eaed75f14 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -6918,27 +6918,53 @@ _array_find_python_scalar_type(PyObject *op)
{
if (PyFloat_Check(op)) {
return PyArray_DescrFromType(PyArray_DOUBLE);
- } else if (PyComplex_Check(op)) {
- return PyArray_DescrFromType(PyArray_CDOUBLE);
- } else if (PyInt_Check(op)) {
- /* bools are a subclass of int */
- if (PyBool_Check(op)) {
- return PyArray_DescrFromType(PyArray_BOOL);
- } else {
- return PyArray_DescrFromType(PyArray_LONG);
- }
- } else if (PyLong_Check(op)) {
- /* if integer can fit into a longlong then return that
- */
- if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) {
- PyErr_Clear();
- return PyArray_DescrFromType(PyArray_OBJECT);
- }
- return PyArray_DescrFromType(PyArray_LONGLONG);
+ }
+ else if (PyComplex_Check(op)) {
+ return PyArray_DescrFromType(PyArray_CDOUBLE);
+ }
+ else if (PyInt_Check(op)) {
+ /* bools are a subclass of int */
+ if (PyBool_Check(op)) {
+ return PyArray_DescrFromType(PyArray_BOOL);
+ } else {
+ return PyArray_DescrFromType(PyArray_LONG);
+ }
+ }
+ else if (PyLong_Check(op)) {
+ /* if integer can fit into a longlong then return that
+ */
+ if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) {
+ PyErr_Clear();
+ return PyArray_DescrFromType(PyArray_OBJECT);
+ }
+ return PyArray_DescrFromType(PyArray_LONGLONG);
}
return NULL;
}
+static PyArray_Descr *
+_use_default_type(PyObject *op)
+{
+ int typenum, l;
+ PyObject *type;
+
+ typenum = -1;
+ l = 0;
+ type = (PyObject *)op->ob_type;
+ while (l < PyArray_NUMUSERTYPES) {
+ if (type == (PyObject *)(userdescrs[l]->typeobj)) {
+ typenum = l + PyArray_USERDEF;
+ break;
+ }
+ l++;
+ }
+ if (typenum == -1) {
+ typenum = PyArray_OBJECT;
+ }
+ return PyArray_DescrFromType(typenum);
+}
+
+
/* op is an object to be converted to an ndarray.
minitype is the minimum type-descriptor needed.
@@ -7086,7 +7112,7 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max)
deflt:
- chktype = PyArray_DescrFromType(PyArray_OBJECT);
+ chktype = _use_default_type(op);
finish: