summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-03-22 14:32:11 -0700
committerMark Wiebe <mwwiebe@gmail.com>2011-03-25 15:24:21 -0700
commit88e8c15841538acb6fe42ebdbeaa93f57af6b27f (patch)
tree8a36609c53dbe22d23a43d5ba759ea71e5182481 /numpy
parent73ea5e7be76d77b6ac67b772b8770b0b12d67722 (diff)
downloadnumpy-88e8c15841538acb6fe42ebdbeaa93f57af6b27f.tar.gz
ENH: Add scalar support for the format() function introduced in Python 2.6 (#1675)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src80
-rw-r--r--numpy/core/tests/test_print.py26
2 files changed, 106 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index eff71d12d..5fecab5ac 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -361,6 +361,79 @@ gentype_repr(PyObject *self)
return ret;
}
+#if PY_VERSION_HEX >= 0x02060000
+/*
+ * The __format__ method for PEP 3101.
+ */
+static PyObject *
+gentype_format(PyObject *self, PyObject *args)
+{
+ PyObject *format_spec;
+ PyObject *obj, *ret;
+
+#if defined(NPY_PY3K)
+ if (!PyArg_ParseTuple(args, "U:__format__", &format_spec)) {
+ return NULL;
+ }
+#else
+ if (!PyArg_ParseTuple(args, "O:__format__", &format_spec)) {
+ return NULL;
+ }
+
+ if (!PyUnicode_Check(format_spec) && !PyString_Check(format_spec)) {
+ PyErr_SetString(PyExc_TypeError,
+ "format must be a string");
+ return NULL;
+ }
+#endif
+
+ /*
+ * Convert to an appropriate Python type and call its format.
+ * TODO: For some types, like long double, this isn't right,
+ * because it throws away precision.
+ */
+ if (Py_TYPE(self) == &PyBoolArrType_Type) {
+ obj = PyBool_FromLong(((PyBoolScalarObject *)self)->obval);
+ }
+ else if (PyArray_IsScalar(self, Integer)) {
+#if defined(NPY_PY3K)
+ obj = Py_TYPE(self)->tp_as_number->nb_int(self);
+#else
+ obj = Py_TYPE(self)->tp_as_number->nb_long(self);
+#endif
+ }
+ else if (PyArray_IsScalar(self, Floating)) {
+ obj = Py_TYPE(self)->tp_as_number->nb_float(self);
+ }
+ else if (PyArray_IsScalar(self, ComplexFloating)) {
+ double val[2];
+ PyArray_Descr *dtype = PyArray_DescrFromScalar(self);
+
+ if (dtype == NULL) {
+ return NULL;
+ }
+ if (PyArray_CastScalarDirect(self, dtype, &val[0], NPY_CDOUBLE) < 0) {
+ Py_DECREF(dtype);
+ return NULL;
+ }
+ obj = PyComplex_FromDoubles(val[0], val[1]);
+ Py_DECREF(dtype);
+ }
+ else {
+ obj = self;
+ Py_INCREF(obj);
+ }
+
+ if (obj == NULL) {
+ return NULL;
+ }
+
+ ret = PyObject_Format(obj, format_spec);
+ Py_DECREF(obj);
+ return ret;
+}
+#endif
+
#ifdef FORCE_NO_LONG_DOUBLE_FORMATTING
#undef NPY_LONGDOUBLE_FMT
#define NPY_LONGDOUBLE_FMT NPY_DOUBLE_FMT
@@ -1684,6 +1757,13 @@ static PyMethodDef gentype_methods[] = {
(PyCFunction)gentype_round,
METH_VARARGS | METH_KEYWORDS, NULL},
#endif
+#if PY_VERSION_HEX >= 0x02060000
+ /* For the format function */
+ {"__format__",
+ gentype_format,
+ METH_VARARGS,
+ "NumPy array scalar formatter"},
+#endif
{"setflags",
(PyCFunction)gentype_setflags,
METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/numpy/core/tests/test_print.py b/numpy/core/tests/test_print.py
index d83f21cb2..16e772a8e 100644
--- a/numpy/core/tests/test_print.py
+++ b/numpy/core/tests/test_print.py
@@ -198,6 +198,32 @@ def test_complex_type_print():
for t in [np.complex64, np.cdouble, np.clongdouble] :
yield check_complex_type_print, t
+@dec.skipif(sys.version_info < (2,6))
+def test_scalar_format():
+ """Test the str.format method with NumPy scalar types"""
+ tests = [('{0}', True, np.bool_),
+ ('{0}', False, np.bool_),
+ ('{0:d}', 130, np.uint8),
+ ('{0:d}', 50000, np.uint16),
+ ('{0:d}', 3000000000, np.uint32),
+ ('{0:d}', 15000000000000000000, np.uint64),
+ ('{0:d}', -120, np.int8),
+ ('{0:d}', -30000, np.int16),
+ ('{0:d}', -2000000000, np.int32),
+ ('{0:d}', -7000000000000000000, np.int64),
+ ('{0:g}', 1.5, np.float16),
+ ('{0:g}', 1.5, np.float32),
+ ('{0:g}', 1.5, np.float64),
+ ('{0:g}', 1.5, np.longdouble),
+ ('{0:g}', 1.5+0.5j, np.complex64),
+ ('{0:g}', 1.5+0.5j, np.complex128),
+ ('{0:g}', 1.5+0.5j, np.clongdouble)]
+
+ for (fmat, val, valtype) in tests:
+ assert_equal(fmat.format(val), fmat.format(valtype(val)),
+ "failed with val %s, type %s" % (val, valtype))
+
+
# Locale tests: scalar types formatting should be independent of the locale
def in_foreign_locale(func):
"""