diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-10-04 09:41:29 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-10-04 09:42:57 -0700 |
commit | d6be04ffda311832b84959b0b5dafce59f6b9401 (patch) | |
tree | ad5504282455fdcb8b23b0153f8dbd56c99384c5 | |
parent | 2995e6a9cfb7f38f1ab8ec102cc301d3ceba480e (diff) | |
download | numpy-d6be04ffda311832b84959b0b5dafce59f6b9401.tar.gz |
BUG: Allow subclasses of MaskedConstant to behave as unique singletons
Fixes astropy/astropy#6645
-rw-r--r-- | numpy/ma/core.py | 12 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 9 |
2 files changed, 18 insertions, 3 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index b6e2edf5a..130817e7a 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -6165,8 +6165,14 @@ class MaskedConstant(MaskedArray): # the lone np.ma.masked instance __singleton = None + @classmethod + def __has_singleton(cls): + # second case ensures `cls.__singleton` is not just a view on the + # superclass singleton + return cls.__singleton is not None and type(cls.__singleton) is cls + def __new__(cls): - if cls.__singleton is None: + if not cls.__has_singleton(): # We define the masked singleton as a float for higher precedence. # Note that it can be tricky sometimes w/ type comparison data = np.array(0.) @@ -6184,7 +6190,7 @@ class MaskedConstant(MaskedArray): return cls.__singleton def __array_finalize__(self, obj): - if self.__singleton is None: + if not self.__has_singleton(): # this handles the `.view` in __new__, which we want to copy across # properties normally return super(MaskedConstant, self).__array_finalize__(obj) @@ -6207,7 +6213,7 @@ class MaskedConstant(MaskedArray): return str(masked_print_option._display) def __repr__(self): - if self is self.__singleton: + if self is MaskedConstant.__singleton: return 'masked' else: # it's a subclass, or something is wrong, make it obvious diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 41c56ca1e..f82652a24 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4880,6 +4880,15 @@ class TestMaskedConstant(object): a_b[()] = np.ma.masked assert_equal(a_b[()], b'--') + def test_subclass(self): + # https://github.com/astropy/astropy/issues/6645 + class Sub(type(np.ma.masked)): pass + + a = Sub() + assert_(a is Sub()) + assert_(a is not np.ma.masked) + assert_not_equal(repr(a), 'masked') + def test_masked_array(): a = np.ma.array([0, 1, 2, 3], mask=[0, 0, 1, 0]) |