diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-11-20 09:13:35 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-20 09:13:35 -0600 |
commit | 70270145f3a77242e59bd3acd5e5e9577ec2a6cc (patch) | |
tree | bff0032c0ff5c6568077bb1cccd6afa6c4e55ef0 | |
parent | 9eaafe540ad7f176900a644560fb19fc54e13b27 (diff) | |
parent | 8b366e0b0fff8bd46397d4d013832efce6e338b1 (diff) | |
download | numpy-70270145f3a77242e59bd3acd5e5e9577ec2a6cc.tar.gz |
Merge pull request #12424 from eric-wieser/rework-11536
BUG: Fix inconsistent cache keying in ndpointer
-rw-r--r-- | numpy/ctypeslib.py | 31 | ||||
-rw-r--r-- | numpy/tests/test_ctypeslib.py | 11 |
2 files changed, 28 insertions, 14 deletions
diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py index 24cfc6762..1158a5c85 100644 --- a/numpy/ctypeslib.py +++ b/numpy/ctypeslib.py @@ -269,8 +269,11 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None): """ + # normalize dtype to an Optional[dtype] if dtype is not None: dtype = _dtype(dtype) + + # normalize flags to an Optional[int] num = None if flags is not None: if isinstance(flags, str): @@ -287,10 +290,23 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None): except Exception: raise TypeError("invalid flags specification") num = _num_fromflags(flags) + + # normalize shape to an Optional[tuple] + if shape is not None: + try: + shape = tuple(shape) + except TypeError: + # single integer -> 1-tuple + shape = (shape,) + + cache_key = (dtype, ndim, shape, num) + try: - return _pointer_type_cache[(dtype, ndim, shape, num)] + return _pointer_type_cache[cache_key] except KeyError: pass + + # produce a name for the new type if dtype is None: name = 'any' elif dtype.names: @@ -300,23 +316,16 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None): if ndim is not None: name += "_%dd" % ndim if shape is not None: - try: - strshape = [str(x) for x in shape] - except TypeError: - strshape = [str(shape)] - shape = (shape,) - shape = tuple(shape) - name += "_"+"x".join(strshape) + name += "_"+"x".join(str(x) for x in shape) if flags is not None: name += "_"+"_".join(flags) - else: - flags = [] + klass = type("ndpointer_%s"%name, (_ndptr,), {"_dtype_": dtype, "_shape_" : shape, "_ndim_" : ndim, "_flags_" : num}) - _pointer_type_cache[(dtype, shape, ndim, num)] = klass + _pointer_type_cache[cache_key] = klass return klass diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py index 675f8d242..a6d73b152 100644 --- a/numpy/tests/test_ctypeslib.py +++ b/numpy/tests/test_ctypeslib.py @@ -108,9 +108,14 @@ class TestNdpointer(object): assert_raises(TypeError, p.from_param, np.array([[1, 2], [3, 4]])) def test_cache(self): - a1 = ndpointer(dtype=np.float64) - a2 = ndpointer(dtype=np.float64) - assert_(a1 == a2) + assert_(ndpointer(dtype=np.float64) is ndpointer(dtype=np.float64)) + + # shapes are normalized + assert_(ndpointer(shape=2) is ndpointer(shape=(2,))) + + # 1.12 <= v < 1.16 had a bug that made these fail + assert_(ndpointer(shape=2) is not ndpointer(ndim=2)) + assert_(ndpointer(ndim=2) is not ndpointer(shape=2)) @pytest.mark.skipif(not _HAS_CTYPE, |