diff options
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/arrayobject.c | 23 | ||||
-rw-r--r-- | numpy/core/src/scalartypes.inc.src | 58 |
2 files changed, 67 insertions, 14 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index a7d0b67b4..75c1b6e29 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2490,7 +2490,7 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) return 0; } -int +static int count_new_axes_0d(PyObject *tuple) { int i, argument_count; @@ -2706,14 +2706,12 @@ array_subscript(PyArrayObject *self, PyObject *op) return NULL; } + if (op == Py_Ellipsis) { + Py_INCREF(self); + return (PyObject *)self; + } + if (self->nd == 0) { - if (op == Py_Ellipsis) { - /* XXX: This leads to a small inconsistency - XXX: with the nd>0 case where (x[...] is x) - XXX: is false for nd>0 case. */ - Py_INCREF(self); - return (PyObject *)self; - } if (op == Py_None) return add_new_axes_0d(self, 1); if (PyTuple_Check(op)) { @@ -2726,9 +2724,8 @@ array_subscript(PyArrayObject *self, PyObject *op) return add_new_axes_0d(self, nd); } /* Allow Boolean mask selection also */ - if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) || - (PyArray_Check(op) && (PyArray_DIMS(op)==0) && - PyArray_ISBOOL(op))) { + if ((PyArray_Check(op) && (PyArray_DIMS(op)==0) && + PyArray_ISBOOL(op))) { if (PyObject_IsTrue(op)) { Py_INCREF(self); return (PyObject *)self; @@ -3051,7 +3048,8 @@ array_subscript_nice(PyArrayObject *self, PyObject *op) if ((op == Py_Ellipsis) || PyString_Check(op) || PyUnicode_Check(op)) noellipses = FALSE; else if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) || - (PyArray_Check(op) && (PyArray_DIMS(op)==0))) + (PyArray_Check(op) && (PyArray_DIMS(op)==0) && + PyArray_ISBOOL(op))) noellipses = FALSE; else if (PySequence_Check(op)) { int n, i; @@ -9212,7 +9210,6 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind) char *dptr; int size; PyObject *obj = NULL; - int swap; PyArray_CopySwapFunc *copyswap; if (ind == Py_Ellipsis) { diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src index 5e74cb0d6..2ae855400 100644 --- a/numpy/core/src/scalartypes.inc.src +++ b/numpy/core/src/scalartypes.inc.src @@ -2373,6 +2373,54 @@ static PyTypeObject PyObjectArrType_Type = { 0, /* tp_flags */ }; + +static PyObject * +add_new_axes_0d(PyArrayObject *, int); + +static int +count_new_axes_0d(PyObject *); + +static PyObject * +gen_arrtype_subscript(PyObject *self, PyObject *key) +{ + /* Only [...], [...,<???>], [<???>, ...], + is allowed for indexing a scalar + + These return a new N-d array with a copy of + the data where N is the number of None's in <???>. + + */ + PyObject *res, *ret; + int N; + + if (key == Py_Ellipsis || key == Py_None || + PyTuple_Check(key)) { + res = PyArray_FromScalar(self, NULL); + } + else { + PyErr_SetString(PyExc_IndexError, + "invalid index to scalar variable."); + return NULL; + } + + if (key == Py_Ellipsis) + return res; + + if (key == Py_None) { + ret = add_new_axes_0d((PyArrayObject *)res, 1); + Py_DECREF(res); + return ret; + } + /* Must be a Tuple */ + + N = count_new_axes_0d(key); + if (N < 0) return NULL; + ret = add_new_axes_0d((PyArrayObject *)res, N); + Py_DECREF(res); + return ret; +} + + /**begin repeat #name=bool, string, unicode, void# #NAME=Bool, String, Unicode, Void# @@ -2418,6 +2466,14 @@ static PyTypeObject Py@NAME@ArrType_Type = { #undef _THIS_SIZE /**end repeat**/ + +static PyMappingMethods gentype_as_mapping = { + NULL, + (binaryfunc)gen_arrtype_subscript, + NULL +}; + + /**begin repeat #NAME=CFloat, CDouble, CLongDouble# #name=complex*3# @@ -2475,7 +2531,6 @@ static PyTypeObject Py@NAME@ArrType_Type = { /**end repeat**/ - static PyNumberMethods longdoubletype_as_number; static PyNumberMethods clongdoubletype_as_number; @@ -2486,6 +2541,7 @@ initialize_numeric_types(void) PyGenericArrType_Type.tp_dealloc = (destructor)gentype_dealloc; PyGenericArrType_Type.tp_as_number = &gentype_as_number; PyGenericArrType_Type.tp_as_buffer = &gentype_as_buffer; + PyGenericArrType_Type.tp_as_mapping = &gentype_as_mapping; PyGenericArrType_Type.tp_flags = BASEFLAGS; PyGenericArrType_Type.tp_methods = gentype_methods; PyGenericArrType_Type.tp_getset = gentype_getsets; |