summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-13 10:03:13 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-13 10:03:13 +0000
commit3fa71a7122f1a9379b31a35a9f3da9b4f902299b (patch)
tree3e769252100f42448df22429162aa9abf4a1bb93 /numpy/lib/utils.py
parenteee00f8f7e15592a048c8b841aef9ea81faa0fda (diff)
downloadnumpy-3fa71a7122f1a9379b31a35a9f3da9b4f902299b.tar.gz
Improve ndpointer to allow shape and flags checking as well.
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py63
1 files changed, 55 insertions, 8 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 14b0d8ea3..c0701c216 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1,8 +1,8 @@
import sys, os
import inspect
import types
-from numpy.core.numerictypes import obj2sctype
-from numpy.core.multiarray import dtype
+from numpy.core.numerictypes import obj2sctype, integer
+from numpy.core.multiarray import dtype, _flagdict
from numpy.core import product, ndarray
__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
@@ -76,30 +76,77 @@ def ctypes_load_library(libname, loader_path):
libpath = os.path.join(libdir, libname)
return ctypes.cdll[libpath]
+def _num_fromflags(flaglist):
+ num = 0
+ for val in flaglist:
+ num += _flagdict[val]
+ return num
+
+def _flags_fromnum(num):
+ res = []
+ for key, value in _flagdict.items():
+ if (num & value):
+ res.append(key)
+ return res
+
class _ndptr(object):
def from_param(cls, obj):
if not isinstance(obj, ndarray):
raise TypeError("argument must be an ndarray")
if obj.dtype != cls._dtype_:
raise TypeError("array must have data type", cls._dtype_)
+ if cls._ndim_ and obj.ndim != cls._ndim_:
+ raise TypeError("array must have %d dimension(s)" % cls._ndim_)
+ if cls._shape_ and obj.shape != cls._shape_:
+ raise TypeError("array must have shape ", cls._shape_)
+ if cls._flags_ and ((obj.flags.num & cls._flags_) != cls._flags_):
+ raise TypeError("array must have flags ",
+ _flags_fromnum(cls._flags_))
return obj.ctypes
from_param = classmethod(from_param)
-# Factory for a type-checking object with from_param defined
+# Factory for an array-checking object with from_param defined
_pointer_type_cache = {}
-def ndpointer(datatype):
+def ndpointer(datatype, ndim=None, shape=None, flags=None):
datatype = dtype(datatype)
+ num = None
+ if flags is not None:
+ if isinstance(flags, str):
+ flags = flags.split(',')
+ elif isinstance(flags, (int, integer)):
+ num = flags
+ flags = _flags_fromnum(flags)
+ if num is None:
+ flags = [x.strip().upper() for x in flags]
+ num = _num_fromflags(flags)
try:
- return _pointer_type_cache[datatype]
+ return _pointer_type_cache[(datatype, ndim, shape, num)]
except KeyError:
- pass
+ pass
if datatype.names:
name = str(id(datatype))
else:
name = datatype.str
+ if ndim is not None:
+ name += "_%dd" % ndim
+ if shape is not None:
+ try:
+ strshape = [str(x) for x in shape]
+ except TypeError:
+ strshape = [str(shape)]
+ shape = (shape,)
+ shape = tuple(shape)
+ name += "_"+"x".join(strshape)
+ if flags is not None:
+ name += "_"+"_".join(flags)
+ else:
+ flags = []
klass = type("ndpointer_%s"%name, (_ndptr,),
- {"_dtype_": datatype})
- _pointer_type_cache[datatype] = klass
+ {"_dtype_": datatype,
+ "_shape_" : shape,
+ "_ndim_" : ndim,
+ "_flags_" : num})
+ _pointer_type_cache[datatype] = klass
return klass