summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-11-20 09:13:35 -0600
committerGitHub <noreply@github.com>2018-11-20 09:13:35 -0600
commit70270145f3a77242e59bd3acd5e5e9577ec2a6cc (patch)
treebff0032c0ff5c6568077bb1cccd6afa6c4e55ef0
parent9eaafe540ad7f176900a644560fb19fc54e13b27 (diff)
parent8b366e0b0fff8bd46397d4d013832efce6e338b1 (diff)
downloadnumpy-70270145f3a77242e59bd3acd5e5e9577ec2a6cc.tar.gz
Merge pull request #12424 from eric-wieser/rework-11536
BUG: Fix inconsistent cache keying in ndpointer
-rw-r--r--numpy/ctypeslib.py31
-rw-r--r--numpy/tests/test_ctypeslib.py11
2 files changed, 28 insertions, 14 deletions
diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py
index 24cfc6762..1158a5c85 100644
--- a/numpy/ctypeslib.py
+++ b/numpy/ctypeslib.py
@@ -269,8 +269,11 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
"""
+ # normalize dtype to an Optional[dtype]
if dtype is not None:
dtype = _dtype(dtype)
+
+ # normalize flags to an Optional[int]
num = None
if flags is not None:
if isinstance(flags, str):
@@ -287,10 +290,23 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
except Exception:
raise TypeError("invalid flags specification")
num = _num_fromflags(flags)
+
+ # normalize shape to an Optional[tuple]
+ if shape is not None:
+ try:
+ shape = tuple(shape)
+ except TypeError:
+ # single integer -> 1-tuple
+ shape = (shape,)
+
+ cache_key = (dtype, ndim, shape, num)
+
try:
- return _pointer_type_cache[(dtype, ndim, shape, num)]
+ return _pointer_type_cache[cache_key]
except KeyError:
pass
+
+ # produce a name for the new type
if dtype is None:
name = 'any'
elif dtype.names:
@@ -300,23 +316,16 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
if ndim is not None:
name += "_%dd" % ndim
if shape is not None:
- try:
- strshape = [str(x) for x in shape]
- except TypeError:
- strshape = [str(shape)]
- shape = (shape,)
- shape = tuple(shape)
- name += "_"+"x".join(strshape)
+ name += "_"+"x".join(str(x) for x in shape)
if flags is not None:
name += "_"+"_".join(flags)
- else:
- flags = []
+
klass = type("ndpointer_%s"%name, (_ndptr,),
{"_dtype_": dtype,
"_shape_" : shape,
"_ndim_" : ndim,
"_flags_" : num})
- _pointer_type_cache[(dtype, shape, ndim, num)] = klass
+ _pointer_type_cache[cache_key] = klass
return klass
diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py
index 675f8d242..a6d73b152 100644
--- a/numpy/tests/test_ctypeslib.py
+++ b/numpy/tests/test_ctypeslib.py
@@ -108,9 +108,14 @@ class TestNdpointer(object):
assert_raises(TypeError, p.from_param, np.array([[1, 2], [3, 4]]))
def test_cache(self):
- a1 = ndpointer(dtype=np.float64)
- a2 = ndpointer(dtype=np.float64)
- assert_(a1 == a2)
+ assert_(ndpointer(dtype=np.float64) is ndpointer(dtype=np.float64))
+
+ # shapes are normalized
+ assert_(ndpointer(shape=2) is ndpointer(shape=(2,)))
+
+ # 1.12 <= v < 1.16 had a bug that made these fail
+ assert_(ndpointer(shape=2) is not ndpointer(ndim=2))
+ assert_(ndpointer(ndim=2) is not ndpointer(shape=2))
@pytest.mark.skipif(not _HAS_CTYPE,