summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2008-02-25 22:51:21 +0000
committerTravis Oliphant <oliphant@enthought.com>2008-02-25 22:51:21 +0000
commitfa547b80f7035da85f66f9cbabc4ff75969d23cd (patch)
tree321b5ef4a11a64d57afd4f9f69c2c674949dd7fb /numpy/core/src
parent5089f5d7714fc8cebaaaf302f4c78cdba739e032 (diff)
downloadnumpy-fa547b80f7035da85f66f9cbabc4ff75969d23cd.tar.gz
Allow numpy scalars to be indexed in limited ways, but not be iterable. Fix consistency bug with [...] indexing and remove useless check and allow 0-d boolean arrays to work as masks for scalars.
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/arrayobject.c23
-rw-r--r--numpy/core/src/scalartypes.inc.src58
2 files changed, 67 insertions, 14 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index a7d0b67b4..75c1b6e29 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2490,7 +2490,7 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
return 0;
}
-int
+static int
count_new_axes_0d(PyObject *tuple)
{
int i, argument_count;
@@ -2706,14 +2706,12 @@ array_subscript(PyArrayObject *self, PyObject *op)
return NULL;
}
+ if (op == Py_Ellipsis) {
+ Py_INCREF(self);
+ return (PyObject *)self;
+ }
+
if (self->nd == 0) {
- if (op == Py_Ellipsis) {
- /* XXX: This leads to a small inconsistency
- XXX: with the nd>0 case where (x[...] is x)
- XXX: is false for nd>0 case. */
- Py_INCREF(self);
- return (PyObject *)self;
- }
if (op == Py_None)
return add_new_axes_0d(self, 1);
if (PyTuple_Check(op)) {
@@ -2726,9 +2724,8 @@ array_subscript(PyArrayObject *self, PyObject *op)
return add_new_axes_0d(self, nd);
}
/* Allow Boolean mask selection also */
- if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) ||
- (PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
- PyArray_ISBOOL(op))) {
+ if ((PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
+ PyArray_ISBOOL(op))) {
if (PyObject_IsTrue(op)) {
Py_INCREF(self);
return (PyObject *)self;
@@ -3051,7 +3048,8 @@ array_subscript_nice(PyArrayObject *self, PyObject *op)
if ((op == Py_Ellipsis) || PyString_Check(op) || PyUnicode_Check(op))
noellipses = FALSE;
else if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) ||
- (PyArray_Check(op) && (PyArray_DIMS(op)==0)))
+ (PyArray_Check(op) && (PyArray_DIMS(op)==0) &&
+ PyArray_ISBOOL(op)))
noellipses = FALSE;
else if (PySequence_Check(op)) {
int n, i;
@@ -9212,7 +9210,6 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
char *dptr;
int size;
PyObject *obj = NULL;
- int swap;
PyArray_CopySwapFunc *copyswap;
if (ind == Py_Ellipsis) {
diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src
index 5e74cb0d6..2ae855400 100644
--- a/numpy/core/src/scalartypes.inc.src
+++ b/numpy/core/src/scalartypes.inc.src
@@ -2373,6 +2373,54 @@ static PyTypeObject PyObjectArrType_Type = {
0, /* tp_flags */
};
+
+static PyObject *
+add_new_axes_0d(PyArrayObject *, int);
+
+static int
+count_new_axes_0d(PyObject *);
+
+static PyObject *
+gen_arrtype_subscript(PyObject *self, PyObject *key)
+{
+ /* Only [...], [...,<???>], [<???>, ...],
+ is allowed for indexing a scalar
+
+ These return a new N-d array with a copy of
+ the data where N is the number of None's in <???>.
+
+ */
+ PyObject *res, *ret;
+ int N;
+
+ if (key == Py_Ellipsis || key == Py_None ||
+ PyTuple_Check(key)) {
+ res = PyArray_FromScalar(self, NULL);
+ }
+ else {
+ PyErr_SetString(PyExc_IndexError,
+ "invalid index to scalar variable.");
+ return NULL;
+ }
+
+ if (key == Py_Ellipsis)
+ return res;
+
+ if (key == Py_None) {
+ ret = add_new_axes_0d((PyArrayObject *)res, 1);
+ Py_DECREF(res);
+ return ret;
+ }
+ /* Must be a Tuple */
+
+ N = count_new_axes_0d(key);
+ if (N < 0) return NULL;
+ ret = add_new_axes_0d((PyArrayObject *)res, N);
+ Py_DECREF(res);
+ return ret;
+}
+
+
/**begin repeat
#name=bool, string, unicode, void#
#NAME=Bool, String, Unicode, Void#
@@ -2418,6 +2466,14 @@ static PyTypeObject Py@NAME@ArrType_Type = {
#undef _THIS_SIZE
/**end repeat**/
+
+static PyMappingMethods gentype_as_mapping = {
+ NULL,
+ (binaryfunc)gen_arrtype_subscript,
+ NULL
+};
+
+
/**begin repeat
#NAME=CFloat, CDouble, CLongDouble#
#name=complex*3#
@@ -2475,7 +2531,6 @@ static PyTypeObject Py@NAME@ArrType_Type = {
/**end repeat**/
-
static PyNumberMethods longdoubletype_as_number;
static PyNumberMethods clongdoubletype_as_number;
@@ -2486,6 +2541,7 @@ initialize_numeric_types(void)
PyGenericArrType_Type.tp_dealloc = (destructor)gentype_dealloc;
PyGenericArrType_Type.tp_as_number = &gentype_as_number;
PyGenericArrType_Type.tp_as_buffer = &gentype_as_buffer;
+ PyGenericArrType_Type.tp_as_mapping = &gentype_as_mapping;
PyGenericArrType_Type.tp_flags = BASEFLAGS;
PyGenericArrType_Type.tp_methods = gentype_methods;
PyGenericArrType_Type.tp_getset = gentype_getsets;