summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-10-27 16:56:32 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-10-27 16:56:32 +0000
commit3c64ccdb050e56a0b112995d4d359c2017d4a4cd (patch)
treefc2d5e23e45da00979af647d1df309bfcceb899d /numpy/core/src
parent50c0847bc0f6dd4df616ac3c12d21dc423e8018c (diff)
downloadnumpy-3c64ccdb050e56a0b112995d4d359c2017d4a4cd.tar.gz
Allow subtypes of all array scalars and fix-up scalar_value to accept sub-types.
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/arrayobject.c1
-rw-r--r--numpy/core/src/scalartypes.inc.src101
2 files changed, 77 insertions, 25 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index cafba4e2c..65f929966 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -1432,6 +1432,7 @@ PyArray_Scalar(void *data, PyArray_Descr *descr, PyObject *base)
}
else {
destptr = scalar_value(obj, descr);
+ if (destptr == NULL) {Py_DECREF(obj); return NULL;}
}
/* copyswap for OBJECT increments the reference count */
copyswap(destptr, data, swap, base);
diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src
index 1bea2ef7e..7ec16bf22 100644
--- a/numpy/core/src/scalartypes.inc.src
+++ b/numpy/core/src/scalartypes.inc.src
@@ -31,7 +31,7 @@ static PyTypeObject Py@NAME@ArrType_Type = {
static void *
scalar_value(PyObject *scalar, PyArray_Descr *descr)
{
- enum NPY_TYPES type_num;
+ int type_num;
if (descr == NULL) {
descr = PyArray_DescrFromScalar(scalar);
type_num = descr->type_num;
@@ -63,10 +63,59 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
case NPY_STRING: return (void *)PyString_AS_STRING(scalar);
case NPY_UNICODE: return (void *)PyUnicode_AS_DATA(scalar);
case NPY_VOID: return ((PyVoidScalarObject *)scalar)->obval;
- default:
- return NULL;
}
+
+ /* Must be a user-defined type --- check to see which
+ scalar it inherits from. */
+
+#define _CHK(cls) (PyObject_IsInstance(scalar, \
+ (PyObject *)&Py##cls##ArrType_Type))
+#define _OBJ(lt) &(((Py##lt##ScalarObject *)scalar)->obval)
+#define _IFCASE(cls) if _CHK(cls) return _OBJ(cls)
+
+ if _CHK(Number) {
+ if _CHK(Integer) {
+ if _CHK(SignedInteger) {
+ _IFCASE(Byte);
+ _IFCASE(Short);
+ _IFCASE(Int);
+ _IFCASE(Long);
+ _IFCASE(LongLong);
+ }
+ else { /* Unsigned Integer */
+ _IFCASE(UByte);
+ _IFCASE(UShort);
+ _IFCASE(UInt);
+ _IFCASE(ULong);
+ _IFCASE(ULongLong);
+ }
+ }
+ else { /* Inexact */
+ if _CHK(Floating) {
+ _IFCASE(Float);
+ _IFCASE(Double);
+ _IFCASE(LongDouble);
+ }
+ else { /*ComplexFloating */
+ _IFCASE(CFloat);
+ _IFCASE(CDouble);
+ _IFCASE(CLongDouble);
+ }
+ }
+ }
+ else if _CHK(Bool) return _OBJ(Bool);
+ else if _CHK(Flexible) {
+ if _CHK(String) return (void *)PyString_AS_STRING(scalar);
+ if _CHK(Unicode) return (void *)PyUnicode_AS_DATA(scalar);
+ if _CHK(Void) return ((PyVoidScalarObject *)scalar)->obval;
+ }
+ else _IFCASE(Object);
+
+ PyErr_SetString(PyExc_RuntimeError, "bad scalar");
return NULL;
+#undef _IFCASE
+#undef _OBJ
+#undef _CHK
}
/* no error checking is performed -- ctypeptr must be same type as scalar */
@@ -143,10 +192,12 @@ PyArray_CastScalarDirect(PyObject *scalar, PyArray_Descr *indescr,
void *ctypeptr, int outtype)
{
PyArray_VectorUnaryFunc* castfunc;
+ void *ptr;
castfunc = PyArray_GetCastFunc(indescr, outtype);
if (castfunc == NULL) return -1;
- castfunc(scalar_value(scalar, indescr),
- ctypeptr, 1, NULL, NULL);
+ ptr = scalar_value(scalar, indescr);
+ if (ptr == NULL) return -1;
+ castfunc(ptr, ctypeptr, 1, NULL, NULL);
return 0;
}
@@ -191,21 +242,12 @@ PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode)
memptr = scalar_value(scalar, typecode);
-
if (memptr == NULL) {
- if (PyDataType_ISUSERDEF(typecode)) {
- /* Use setitem to set array from scalar */
- if (typecode->f->setitem(scalar,
- PyArray_DATA(r), r) < 0) {
- Py_XDECREF(outcode);
- Py_DECREF(r);
- return NULL;
- }
- }
- PyErr_SetString(PyExc_ValueError, "invalid scalar");
+ Py_XDECREF(outcode);
+ Py_DECREF(r);
return NULL;
}
-
+
#ifndef Py_UNICODE_WIDE
if (typecode->type_num == PyArray_UNICODE) {
PyUCS2Buffer_AsUCS4((Py_UNICODE *)memptr,
@@ -827,9 +869,11 @@ gentype_real_get(PyObject *self)
int typenum;
if (PyArray_IsScalar(self, ComplexFloating)) {
+ void *ptr;
typecode = _realdescr_fromcomplexscalar(self, &typenum);
- ret = PyArray_Scalar(scalar_value(self, NULL),
- typecode, NULL);
+ ptr = scalar_value(self, NULL);
+ if (ptr == NULL) {Py_DECREF(typecode); return NULL;}
+ ret = PyArray_Scalar(ptr, typecode, NULL);
Py_DECREF(typecode);
return ret;
}
@@ -851,9 +895,15 @@ gentype_imag_get(PyObject *self)
int typenum;
if (PyArray_IsScalar(self, ComplexFloating)) {
+ char *ptr;
typecode = _realdescr_fromcomplexscalar(self, &typenum);
- ret = PyArray_Scalar((char *)scalar_value(self, NULL)
- + typecode->elsize, typecode, NULL);
+ ptr = (char *)scalar_value(self, NULL);
+ if (ptr == NULL) {
+ Py_DECREF(typecode);
+ return NULL;
+ }
+ ret = PyArray_Scalar(ptr + typecode->elsize,
+ typecode, NULL);
}
else if (PyArray_IsScalar(self, Object)) {
PyObject *obj = ((PyObjectScalarObject *)self)->obval;
@@ -1102,9 +1152,12 @@ voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
if (PyArray_IsScalar(ret, Generic) && \
(!PyArray_IsScalar(ret, Void))) {
PyArray_Descr *new;
+ void *ptr;
if (!PyArray_ISNBO(self->descr->byteorder)) {
new = PyArray_DescrFromScalar(ret);
- byte_swap_vector(scalar_value(ret, new), 1, new->elsize);
+ ptr = scalar_value(ret, new);
+ if (ptr == NULL) {Py_DECREF(new); return NULL;}
+ byte_swap_vector(ptr, 1, new->elsize);
Py_DECREF(new);
}
}
@@ -2366,12 +2419,10 @@ ComplexFloating, Flexible, Character#
#name=bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong, ulonglong, float, double, longdouble, cfloat, cdouble, clongdouble, string, unicode, void, object#
#NAME=Bool, Byte, Short, Int, Long, LongLong, UByte, UShort, UInt, ULong, ULongLong, Float, Double, LongDouble, CFloat, CDouble, CLongDouble, String, Unicode, Void, Object#
*/
- Py@NAME@ArrType_Type.tp_flags = LEAFFLAGS;
+ Py@NAME@ArrType_Type.tp_flags = BASEFLAGS;
Py@NAME@ArrType_Type.tp_new = @name@_arrtype_new;
Py@NAME@ArrType_Type.tp_richcompare = gentype_richcompare;
/**end repeat**/
- /* Allow the Void type to be subclassed -- for adding new types */
- PyVoidArrType_Type.tp_flags = BASEFLAGS;
/**begin repeat
#name=bool, byte, short, ubyte, ushort, uint, ulong, ulonglong, float, longdouble, cfloat, clongdouble, void, object#