summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py30
1 files changed, 28 insertions, 2 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 8306b799c..14b0d8ea3 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -7,8 +7,8 @@ from numpy.core import product, ndarray
__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
'issubdtype', 'deprecate', 'get_numarray_include',
- 'get_include', 'ctypes_load_library', 'info',
- 'source', 'who']
+ 'get_include', 'ctypes_load_library', 'ndpointer',
+ 'info', 'source', 'who']
def issubclass_(arg1, arg2):
try:
@@ -76,6 +76,32 @@ def ctypes_load_library(libname, loader_path):
libpath = os.path.join(libdir, libname)
return ctypes.cdll[libpath]
+class _ndptr(object):
+ def from_param(cls, obj):
+ if not isinstance(obj, ndarray):
+ raise TypeError("argument must be an ndarray")
+ if obj.dtype != cls._dtype_:
+ raise TypeError("array must have data type", cls._dtype_)
+ return obj.ctypes
+ from_param = classmethod(from_param)
+
+# Factory for a type-checking object with from_param defined
+_pointer_type_cache = {}
+def ndpointer(datatype):
+ datatype = dtype(datatype)
+ try:
+ return _pointer_type_cache[datatype]
+ except KeyError:
+ pass
+ if datatype.names:
+ name = str(id(datatype))
+ else:
+ name = datatype.str
+ klass = type("ndpointer_%s"%name, (_ndptr,),
+ {"_dtype_": datatype})
+ _pointer_type_cache[datatype] = klass
+ return klass
+
if sys.version_info < (2, 4):
# Can't set __name__ in 2.3