diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-13 10:03:13 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-13 10:03:13 +0000 |
commit | 3fa71a7122f1a9379b31a35a9f3da9b4f902299b (patch) | |
tree | 3e769252100f42448df22429162aa9abf4a1bb93 /numpy/lib/utils.py | |
parent | eee00f8f7e15592a048c8b841aef9ea81faa0fda (diff) | |
download | numpy-3fa71a7122f1a9379b31a35a9f3da9b4f902299b.tar.gz |
Improve ndpointer to allow shape and flags checking as well.
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 63 |
1 files changed, 55 insertions, 8 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 14b0d8ea3..c0701c216 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1,8 +1,8 @@ import sys, os import inspect import types -from numpy.core.numerictypes import obj2sctype -from numpy.core.multiarray import dtype +from numpy.core.numerictypes import obj2sctype, integer +from numpy.core.multiarray import dtype, _flagdict from numpy.core import product, ndarray __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype', @@ -76,30 +76,77 @@ def ctypes_load_library(libname, loader_path): libpath = os.path.join(libdir, libname) return ctypes.cdll[libpath] +def _num_fromflags(flaglist): + num = 0 + for val in flaglist: + num += _flagdict[val] + return num + +def _flags_fromnum(num): + res = [] + for key, value in _flagdict.items(): + if (num & value): + res.append(key) + return res + class _ndptr(object): def from_param(cls, obj): if not isinstance(obj, ndarray): raise TypeError("argument must be an ndarray") if obj.dtype != cls._dtype_: raise TypeError("array must have data type", cls._dtype_) + if cls._ndim_ and obj.ndim != cls._ndim_: + raise TypeError("array must have %d dimension(s)" % cls._ndim_) + if cls._shape_ and obj.shape != cls._shape_: + raise TypeError("array must have shape ", cls._shape_) + if cls._flags_ and ((obj.flags.num & cls._flags_) != cls._flags_): + raise TypeError("array must have flags ", + _flags_fromnum(cls._flags_)) return obj.ctypes from_param = classmethod(from_param) -# Factory for a type-checking object with from_param defined +# Factory for an array-checking object with from_param defined _pointer_type_cache = {} -def ndpointer(datatype): +def ndpointer(datatype, ndim=None, shape=None, flags=None): datatype = dtype(datatype) + num = None + if flags is not None: + if isinstance(flags, str): + flags = flags.split(',') + elif isinstance(flags, (int, integer)): + num = flags + flags = _flags_fromnum(flags) + if num is None: + flags = [x.strip().upper() for x in flags] + num = _num_fromflags(flags) try: - return _pointer_type_cache[datatype] + return _pointer_type_cache[(datatype, ndim, shape, num)] except KeyError: - pass + pass if datatype.names: name = str(id(datatype)) else: name = datatype.str + if ndim is not None: + name += "_%dd" % ndim + if shape is not None: + try: + strshape = [str(x) for x in shape] + except TypeError: + strshape = [str(shape)] + shape = (shape,) + shape = tuple(shape) + name += "_"+"x".join(strshape) + if flags is not None: + name += "_"+"_".join(flags) + else: + flags = [] klass = type("ndpointer_%s"%name, (_ndptr,), - {"_dtype_": datatype}) - _pointer_type_cache[datatype] = klass + {"_dtype_": datatype, + "_shape_" : shape, + "_ndim_" : ndim, + "_flags_" : num}) + _pointer_type_cache[datatype] = klass return klass |