summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-06-24 09:10:53 -0700
committerGitHub <noreply@github.com>2019-06-24 09:10:53 -0700
commit0a00dc9d5a6c7ad87a8117bace5a6775b0d737df (patch)
tree1fb4828019b35348443209a2a243b0f6074e6357
parent871fac7d1c08d00533f8351b8d32ad0f6e30dab1 (diff)
parent22d5415f4e5e2e7009ccd86ac7915ba43a0b7d97 (diff)
downloadnumpy-0a00dc9d5a6c7ad87a8117bace5a6775b0d737df.tar.gz
Merge pull request #13813 from mhvk/histogram2d-dispatcher-fixup
BUG: further fixup to histogram2d dispatcher.
-rw-r--r--numpy/lib/tests/test_twodim_base.py26
-rw-r--r--numpy/lib/twodim_base.py2
2 files changed, 26 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py
index bf93b4adb..bb844e4bd 100644
--- a/numpy/lib/tests/test_twodim_base.py
+++ b/numpy/lib/tests/test_twodim_base.py
@@ -5,7 +5,7 @@ from __future__ import division, absolute_import, print_function
from numpy.testing import (
assert_equal, assert_array_equal, assert_array_max_ulp,
- assert_array_almost_equal, assert_raises,
+ assert_array_almost_equal, assert_raises, assert_
)
from numpy import (
@@ -17,6 +17,9 @@ from numpy import (
import numpy as np
+from numpy.core.tests.test_overrides import requires_array_function
+
+
def get_mat(n):
data = arange(n)
data = add.outer(data, data)
@@ -273,6 +276,27 @@ class TestHistogram2d(object):
assert_array_equal(H, answer)
assert_array_equal(xe, array([0., 0.25, 0.5, 0.75, 1]))
+ @requires_array_function
+ def test_dispatch(self):
+ class ShouldDispatch:
+ def __array_function__(self, function, types, args, kwargs):
+ return types, args, kwargs
+
+ xy = [1, 2]
+ s_d = ShouldDispatch()
+ r = histogram2d(s_d, xy)
+ # Cannot use assert_equal since that dispatches...
+ assert_(r == ((ShouldDispatch,), (s_d, xy), {}))
+ r = histogram2d(xy, s_d)
+ assert_(r == ((ShouldDispatch,), (xy, s_d), {}))
+ r = histogram2d(xy, xy, bins=s_d)
+ assert_(r, ((ShouldDispatch,), (xy, xy), dict(bins=s_d)))
+ r = histogram2d(xy, xy, bins=[s_d, 5])
+ assert_(r, ((ShouldDispatch,), (xy, xy), dict(bins=[s_d, 5])))
+ assert_raises(Exception, histogram2d, xy, xy, bins=[s_d])
+ r = histogram2d(xy, xy, weights=s_d)
+ assert_(r, ((ShouldDispatch,), (xy, xy), dict(weights=s_d)))
+
class TestTri(object):
def test_dtype(self):
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 0b4e3021a..f3dc6c8e1 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -573,7 +573,7 @@ def _histogram2d_dispatcher(x, y, bins=None, range=None, normed=None,
N = len(bins)
except TypeError:
N = 1
- if N != 1 and N != 2:
+ if N == 2:
yield from bins # bins=[x, y]
else:
yield bins