summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-11-19 23:30:31 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-11-19 23:30:31 -0800
commit8b366e0b0fff8bd46397d4d013832efce6e338b1 (patch)
tree25468b3153c4f931e6b2038ff730e1e656a6f0d3
parent1466e788a43b8d4356fe35951bf0c3b0aedb554f (diff)
downloadnumpy-8b366e0b0fff8bd46397d4d013832efce6e338b1.tar.gz
BUG: Fix inconsistent cache keying in ndpointer
Fixes an alarming bug introduced in gh-7311 (1.12) where the following is true np.ctypeslib.ndpointer(ndim=2) is np.ctypeslib.ndpointer(shape=2) Rework of gh-11536
-rw-r--r--numpy/ctypeslib.py31
-rw-r--r--numpy/tests/test_ctypeslib.py11
2 files changed, 28 insertions, 14 deletions
diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py
index 24cfc6762..1158a5c85 100644
--- a/numpy/ctypeslib.py
+++ b/numpy/ctypeslib.py
@@ -269,8 +269,11 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
"""
+ # normalize dtype to an Optional[dtype]
if dtype is not None:
dtype = _dtype(dtype)
+
+ # normalize flags to an Optional[int]
num = None
if flags is not None:
if isinstance(flags, str):
@@ -287,10 +290,23 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
except Exception:
raise TypeError("invalid flags specification")
num = _num_fromflags(flags)
+
+ # normalize shape to an Optional[tuple]
+ if shape is not None:
+ try:
+ shape = tuple(shape)
+ except TypeError:
+ # single integer -> 1-tuple
+ shape = (shape,)
+
+ cache_key = (dtype, ndim, shape, num)
+
try:
- return _pointer_type_cache[(dtype, ndim, shape, num)]
+ return _pointer_type_cache[cache_key]
except KeyError:
pass
+
+ # produce a name for the new type
if dtype is None:
name = 'any'
elif dtype.names:
@@ -300,23 +316,16 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
if ndim is not None:
name += "_%dd" % ndim
if shape is not None:
- try:
- strshape = [str(x) for x in shape]
- except TypeError:
- strshape = [str(shape)]
- shape = (shape,)
- shape = tuple(shape)
- name += "_"+"x".join(strshape)
+ name += "_"+"x".join(str(x) for x in shape)
if flags is not None:
name += "_"+"_".join(flags)
- else:
- flags = []
+
klass = type("ndpointer_%s"%name, (_ndptr,),
{"_dtype_": dtype,
"_shape_" : shape,
"_ndim_" : ndim,
"_flags_" : num})
- _pointer_type_cache[(dtype, shape, ndim, num)] = klass
+ _pointer_type_cache[cache_key] = klass
return klass
diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py
index 675f8d242..a6d73b152 100644
--- a/numpy/tests/test_ctypeslib.py
+++ b/numpy/tests/test_ctypeslib.py
@@ -108,9 +108,14 @@ class TestNdpointer(object):
assert_raises(TypeError, p.from_param, np.array([[1, 2], [3, 4]]))
def test_cache(self):
- a1 = ndpointer(dtype=np.float64)
- a2 = ndpointer(dtype=np.float64)
- assert_(a1 == a2)
+ assert_(ndpointer(dtype=np.float64) is ndpointer(dtype=np.float64))
+
+ # shapes are normalized
+ assert_(ndpointer(shape=2) is ndpointer(shape=(2,)))
+
+ # 1.12 <= v < 1.16 had a bug that made these fail
+ assert_(ndpointer(shape=2) is not ndpointer(ndim=2))
+ assert_(ndpointer(ndim=2) is not ndpointer(shape=2))
@pytest.mark.skipif(not _HAS_CTYPE,