summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Goerz <goerz@stanford.edu>2016-08-15 00:31:43 -0400
committerMichael Goerz <goerz@stanford.edu>2016-08-15 00:31:43 -0400
commit60b3727e6937891a9b91bac4ffb097bb3b5bd25d (patch)
treea04f39ee547c8b88dfbdffeb7a157a08b4e7113d
parente1f191c46f2eebd6cb892a4bfe14d9dd43a06c4e (diff)
downloadnumpy-60b3727e6937891a9b91bac4ffb097bb3b5bd25d.tar.gz
ENH: improve duck typing inside iscomplexobj
Both `iscomplexobj` and `isrealobj` now try to refer to the argument's `dtype` attribute if it exists. This significantly extends the list of types for which `iscomplexobj` returns correct results (including e.g. scipy sparse matrices and pandas objects). Extended the tests of the `iscomplexobj` routine for the following cases: * simple scalars * standard lists (test internal auto-conversion to numpy arrays) * "Duck typing" for objects that define a dtype attribute (either referring to one of the existing numpy dtypes, or a custom dtype, as pandas does) This fixes #7924
-rw-r--r--numpy/lib/tests/test_type_check.py35
-rw-r--r--numpy/lib/type_check.py12
2 files changed, 45 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index f7430c27d..93a4da97a 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -148,6 +148,41 @@ class TestIscomplexobj(TestCase):
z = np.array([-1j, 0, -1])
assert_(iscomplexobj(z))
+ def test_scalar(self):
+ assert_(not iscomplexobj(1.0))
+ assert_(iscomplexobj(1+0j))
+
+ def test_list(self):
+ assert_(iscomplexobj([3, 1+0j, True]))
+ assert_(not iscomplexobj([3, 1, True]))
+
+ def test_duck(self):
+ class DummyComplexArray:
+ @property
+ def dtype(self):
+ return np.dtype(complex)
+ dummy = DummyComplexArray()
+ assert_(iscomplexobj(dummy))
+
+ def test_pandas_duck(self):
+ # This tests a custom np.dtype duck-typed class, such as used by pandas
+ # (pandas.core.dtypes)
+ class PdComplex(np.complex128):
+ pass
+ class PdDtype(object):
+ name = 'category'
+ names = None
+ type = PdComplex
+ kind = 'c'
+ str = '<c16'
+ base = np.dtype('complex128')
+ class DummyPd:
+ @property
+ def dtype(self):
+ return PdDtype
+ dummy = DummyPd()
+ assert_(iscomplexobj(dummy))
+
class TestIsrealobj(TestCase):
def test_basic(self):
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index d6e0704ad..f620d49d5 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -266,7 +266,15 @@ def iscomplexobj(x):
True
"""
- return issubclass(asarray(x).dtype.type, _nx.complexfloating)
+ try:
+ dtype = x.dtype
+ except AttributeError:
+ dtype = asarray(x).dtype
+ try:
+ return issubclass(dtype.type, _nx.complexfloating)
+ except AttributeError:
+ return False
+
def isrealobj(x):
"""
@@ -300,7 +308,7 @@ def isrealobj(x):
False
"""
- return not issubclass(asarray(x).dtype.type, _nx.complexfloating)
+ return not iscomplexobj(x)
#-----------------------------------------------------------------------------