diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-03-22 14:32:11 -0700 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-03-25 15:24:21 -0700 |
commit | 88e8c15841538acb6fe42ebdbeaa93f57af6b27f (patch) | |
tree | 8a36609c53dbe22d23a43d5ba759ea71e5182481 /numpy | |
parent | 73ea5e7be76d77b6ac67b772b8770b0b12d67722 (diff) | |
download | numpy-88e8c15841538acb6fe42ebdbeaa93f57af6b27f.tar.gz |
ENH: Add scalar support for the format() function introduced in Python 2.6 (#1675)
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 80 | ||||
-rw-r--r-- | numpy/core/tests/test_print.py | 26 |
2 files changed, 106 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index eff71d12d..5fecab5ac 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -361,6 +361,79 @@ gentype_repr(PyObject *self) return ret; } +#if PY_VERSION_HEX >= 0x02060000 +/* + * The __format__ method for PEP 3101. + */ +static PyObject * +gentype_format(PyObject *self, PyObject *args) +{ + PyObject *format_spec; + PyObject *obj, *ret; + +#if defined(NPY_PY3K) + if (!PyArg_ParseTuple(args, "U:__format__", &format_spec)) { + return NULL; + } +#else + if (!PyArg_ParseTuple(args, "O:__format__", &format_spec)) { + return NULL; + } + + if (!PyUnicode_Check(format_spec) && !PyString_Check(format_spec)) { + PyErr_SetString(PyExc_TypeError, + "format must be a string"); + return NULL; + } +#endif + + /* + * Convert to an appropriate Python type and call its format. + * TODO: For some types, like long double, this isn't right, + * because it throws away precision. + */ + if (Py_TYPE(self) == &PyBoolArrType_Type) { + obj = PyBool_FromLong(((PyBoolScalarObject *)self)->obval); + } + else if (PyArray_IsScalar(self, Integer)) { +#if defined(NPY_PY3K) + obj = Py_TYPE(self)->tp_as_number->nb_int(self); +#else + obj = Py_TYPE(self)->tp_as_number->nb_long(self); +#endif + } + else if (PyArray_IsScalar(self, Floating)) { + obj = Py_TYPE(self)->tp_as_number->nb_float(self); + } + else if (PyArray_IsScalar(self, ComplexFloating)) { + double val[2]; + PyArray_Descr *dtype = PyArray_DescrFromScalar(self); + + if (dtype == NULL) { + return NULL; + } + if (PyArray_CastScalarDirect(self, dtype, &val[0], NPY_CDOUBLE) < 0) { + Py_DECREF(dtype); + return NULL; + } + obj = PyComplex_FromDoubles(val[0], val[1]); + Py_DECREF(dtype); + } + else { + obj = self; + Py_INCREF(obj); + } + + if (obj == NULL) { + return NULL; + } + + ret = PyObject_Format(obj, format_spec); + Py_DECREF(obj); + return ret; +} +#endif + #ifdef FORCE_NO_LONG_DOUBLE_FORMATTING #undef NPY_LONGDOUBLE_FMT #define NPY_LONGDOUBLE_FMT NPY_DOUBLE_FMT @@ -1684,6 +1757,13 @@ static PyMethodDef gentype_methods[] = { (PyCFunction)gentype_round, METH_VARARGS | METH_KEYWORDS, NULL}, #endif +#if PY_VERSION_HEX >= 0x02060000 + /* For the format function */ + {"__format__", + gentype_format, + METH_VARARGS, + "NumPy array scalar formatter"}, +#endif {"setflags", (PyCFunction)gentype_setflags, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/numpy/core/tests/test_print.py b/numpy/core/tests/test_print.py index d83f21cb2..16e772a8e 100644 --- a/numpy/core/tests/test_print.py +++ b/numpy/core/tests/test_print.py @@ -198,6 +198,32 @@ def test_complex_type_print(): for t in [np.complex64, np.cdouble, np.clongdouble] : yield check_complex_type_print, t +@dec.skipif(sys.version_info < (2,6)) +def test_scalar_format(): + """Test the str.format method with NumPy scalar types""" + tests = [('{0}', True, np.bool_), + ('{0}', False, np.bool_), + ('{0:d}', 130, np.uint8), + ('{0:d}', 50000, np.uint16), + ('{0:d}', 3000000000, np.uint32), + ('{0:d}', 15000000000000000000, np.uint64), + ('{0:d}', -120, np.int8), + ('{0:d}', -30000, np.int16), + ('{0:d}', -2000000000, np.int32), + ('{0:d}', -7000000000000000000, np.int64), + ('{0:g}', 1.5, np.float16), + ('{0:g}', 1.5, np.float32), + ('{0:g}', 1.5, np.float64), + ('{0:g}', 1.5, np.longdouble), + ('{0:g}', 1.5+0.5j, np.complex64), + ('{0:g}', 1.5+0.5j, np.complex128), + ('{0:g}', 1.5+0.5j, np.clongdouble)] + + for (fmat, val, valtype) in tests: + assert_equal(fmat.format(val), fmat.format(valtype(val)), + "failed with val %s, type %s" % (val, valtype)) + + # Locale tests: scalar types formatting should be independent of the locale def in_foreign_locale(func): """ |