diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-10-27 16:56:32 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-10-27 16:56:32 +0000 |
commit | 3c64ccdb050e56a0b112995d4d359c2017d4a4cd (patch) | |
tree | fc2d5e23e45da00979af647d1df309bfcceb899d /numpy/core/src | |
parent | 50c0847bc0f6dd4df616ac3c12d21dc423e8018c (diff) | |
download | numpy-3c64ccdb050e56a0b112995d4d359c2017d4a4cd.tar.gz |
Allow subtypes of all array scalars and fix-up scalar_value to accept sub-types.
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/arrayobject.c | 1 | ||||
-rw-r--r-- | numpy/core/src/scalartypes.inc.src | 101 |
2 files changed, 77 insertions, 25 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index cafba4e2c..65f929966 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -1432,6 +1432,7 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base) } else { destptr = scalar_value(obj, descr); + if (destptr == NULL) {Py_DECREF(obj); return NULL;} } /* copyswap for OBJECT increments the reference count */ copyswap(destptr, data, swap, base); diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src index 1bea2ef7e..7ec16bf22 100644 --- a/numpy/core/src/scalartypes.inc.src +++ b/numpy/core/src/scalartypes.inc.src @@ -31,7 +31,7 @@ static PyTypeObject Py@NAME@ArrType_Type = { static void * scalar_value(PyObject *scalar, PyArray_Descr *descr) { - enum NPY_TYPES type_num; + int type_num; if (descr == NULL) { descr = PyArray_DescrFromScalar(scalar); type_num = descr->type_num; @@ -63,10 +63,59 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr) case NPY_STRING: return (void *)PyString_AS_STRING(scalar); case NPY_UNICODE: return (void *)PyUnicode_AS_DATA(scalar); case NPY_VOID: return ((PyVoidScalarObject *)scalar)->obval; - default: - return NULL; } + + /* Must be a user-defined type --- check to see which + scalar it inherits from. */ + +#define _CHK(cls) (PyObject_IsInstance(scalar, \ + (PyObject *)&Py##cls##ArrType_Type)) +#define _OBJ(lt) &(((Py##lt##ScalarObject *)scalar)->obval) +#define _IFCASE(cls) if _CHK(cls) return _OBJ(cls) + + if _CHK(Number) { + if _CHK(Integer) { + if _CHK(SignedInteger) { + _IFCASE(Byte); + _IFCASE(Short); + _IFCASE(Int); + _IFCASE(Long); + _IFCASE(LongLong); + } + else { /* Unsigned Integer */ + _IFCASE(UByte); + _IFCASE(UShort); + _IFCASE(UInt); + _IFCASE(ULong); + _IFCASE(ULongLong); + } + } + else { /* Inexact */ + if _CHK(Floating) { + _IFCASE(Float); + _IFCASE(Double); + _IFCASE(LongDouble); + } + else { /*ComplexFloating */ + _IFCASE(CFloat); + _IFCASE(CDouble); + _IFCASE(CLongDouble); + } + } + } + else if _CHK(Bool) return _OBJ(Bool); + else if _CHK(Flexible) { + if _CHK(String) return (void *)PyString_AS_STRING(scalar); + if _CHK(Unicode) return (void *)PyUnicode_AS_DATA(scalar); + if _CHK(Void) return ((PyVoidScalarObject *)scalar)->obval; + } + else _IFCASE(Object); + + PyErr_SetString(PyExc_RuntimeError, "bad scalar"); return NULL; +#undef _IFCASE +#undef _OBJ +#undef _CHK } /* no error checking is performed -- ctypeptr must be same type as scalar */ @@ -143,10 +192,12 @@ PyArray_CastScalarDirect(PyObject *scalar, PyArray_Descr *indescr, void *ctypeptr, int outtype) { PyArray_VectorUnaryFunc* castfunc; + void *ptr; castfunc = PyArray_GetCastFunc(indescr, outtype); if (castfunc == NULL) return -1; - castfunc(scalar_value(scalar, indescr), - ctypeptr, 1, NULL, NULL); + ptr = scalar_value(scalar, indescr); + if (ptr == NULL) return -1; + castfunc(ptr, ctypeptr, 1, NULL, NULL); return 0; } @@ -191,21 +242,12 @@ PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode) memptr = scalar_value(scalar, typecode); - if (memptr == NULL) { - if (PyDataType_ISUSERDEF(typecode)) { - /* Use setitem to set array from scalar */ - if (typecode->f->setitem(scalar, - PyArray_DATA(r), r) < 0) { - Py_XDECREF(outcode); - Py_DECREF(r); - return NULL; - } - } - PyErr_SetString(PyExc_ValueError, "invalid scalar"); + Py_XDECREF(outcode); + Py_DECREF(r); return NULL; } - + #ifndef Py_UNICODE_WIDE if (typecode->type_num == PyArray_UNICODE) { PyUCS2Buffer_AsUCS4((Py_UNICODE *)memptr, @@ -827,9 +869,11 @@ gentype_real_get(PyObject *self) int typenum; if (PyArray_IsScalar(self, ComplexFloating)) { + void *ptr; typecode = _realdescr_fromcomplexscalar(self, &typenum); - ret = PyArray_Scalar(scalar_value(self, NULL), - typecode, NULL); + ptr = scalar_value(self, NULL); + if (ptr == NULL) {Py_DECREF(typecode); return NULL;} + ret = PyArray_Scalar(ptr, typecode, NULL); Py_DECREF(typecode); return ret; } @@ -851,9 +895,15 @@ gentype_imag_get(PyObject *self) int typenum; if (PyArray_IsScalar(self, ComplexFloating)) { + char *ptr; typecode = _realdescr_fromcomplexscalar(self, &typenum); - ret = PyArray_Scalar((char *)scalar_value(self, NULL) - + typecode->elsize, typecode, NULL); + ptr = (char *)scalar_value(self, NULL); + if (ptr == NULL) { + Py_DECREF(typecode); + return NULL; + } + ret = PyArray_Scalar(ptr + typecode->elsize, + typecode, NULL); } else if (PyArray_IsScalar(self, Object)) { PyObject *obj = ((PyObjectScalarObject *)self)->obval; @@ -1102,9 +1152,12 @@ voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) if (PyArray_IsScalar(ret, Generic) && \ (!PyArray_IsScalar(ret, Void))) { PyArray_Descr *new; + void *ptr; if (!PyArray_ISNBO(self->descr->byteorder)) { new = PyArray_DescrFromScalar(ret); - byte_swap_vector(scalar_value(ret, new), 1, new->elsize); + ptr = scalar_value(ret, new); + if (ptr == NULL) {Py_DECREF(new); return NULL;} + byte_swap_vector(ptr, 1, new->elsize); Py_DECREF(new); } } @@ -2366,12 +2419,10 @@ ComplexFloating, Flexible, Character# #name=bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong, ulonglong, float, double, longdouble, cfloat, cdouble, clongdouble, string, unicode, void, object# #NAME=Bool, Byte, Short, Int, Long, LongLong, UByte, UShort, UInt, ULong, ULongLong, Float, Double, LongDouble, CFloat, CDouble, CLongDouble, String, Unicode, Void, Object# */ - Py@NAME@ArrType_Type.tp_flags = LEAFFLAGS; + Py@NAME@ArrType_Type.tp_flags = BASEFLAGS; Py@NAME@ArrType_Type.tp_new = @name@_arrtype_new; Py@NAME@ArrType_Type.tp_richcompare = gentype_richcompare; /**end repeat**/ - /* Allow the Void type to be subclassed -- for adding new types */ - PyVoidArrType_Type.tp_flags = BASEFLAGS; /**begin repeat #name=bool, byte, short, ubyte, ushort, uint, ulong, ulonglong, float, longdouble, cfloat, clongdouble, void, object# |