diff options
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 8306b799c..14b0d8ea3 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -7,8 +7,8 @@ from numpy.core import product, ndarray __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype', 'issubdtype', 'deprecate', 'get_numarray_include', - 'get_include', 'ctypes_load_library', 'info', - 'source', 'who'] + 'get_include', 'ctypes_load_library', 'ndpointer', + 'info', 'source', 'who'] def issubclass_(arg1, arg2): try: @@ -76,6 +76,32 @@ def ctypes_load_library(libname, loader_path): libpath = os.path.join(libdir, libname) return ctypes.cdll[libpath] +class _ndptr(object): + def from_param(cls, obj): + if not isinstance(obj, ndarray): + raise TypeError("argument must be an ndarray") + if obj.dtype != cls._dtype_: + raise TypeError("array must have data type", cls._dtype_) + return obj.ctypes + from_param = classmethod(from_param) + +# Factory for a type-checking object with from_param defined +_pointer_type_cache = {} +def ndpointer(datatype): + datatype = dtype(datatype) + try: + return _pointer_type_cache[datatype] + except KeyError: + pass + if datatype.names: + name = str(id(datatype)) + else: + name = datatype.str + klass = type("ndpointer_%s"%name, (_ndptr,), + {"_dtype_": datatype}) + _pointer_type_cache[datatype] = klass + return klass + if sys.version_info < (2, 4): # Can't set __name__ in 2.3 |