summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorfivemok <9394929+fivemok@users.noreply.github.com>2018-05-15 13:33:05 -0400
committerEric Wieser <wieser.eric@gmail.com>2018-05-15 10:33:05 -0700
commit290aa88e50e65c2dd5df4ddb062b895e9e7baae3 (patch)
tree6447ee92779f9e852b768b473e97b0d7d86ad801 /numpy
parent79ec65de3064cbdc687820ab09f20b989028509d (diff)
downloadnumpy-290aa88e50e65c2dd5df4ddb062b895e9e7baae3.tar.gz
BUG: remove fast scalar power for arrays with object dtype (#11050)
Fixes #11047 If objects want to optimize `x**2` as `x*x`, they can do it inside their own `__pow__`, rather than relying on numpy to do it.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/number.c4
-rw-r--r--numpy/core/tests/test_multiarray.py32
2 files changed, 35 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 915d743c8..14389a925 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -476,7 +476,9 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace,
double exponent;
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
- if (PyArray_Check(a1) && ((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
+ if (PyArray_Check(a1) &&
+ !PyArray_ISOBJECT(a1) &&
+ ((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 46bdc19a7..3ca201edd 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3317,7 +3317,39 @@ class TestBinop(object):
with assert_raises(NotImplementedError):
a ** 2
+ def test_pow_array_object_dtype(self):
+ # test pow on arrays of object dtype
+ class SomeClass(object):
+ def __init__(self, num=None):
+ self.num = num
+
+ # want to ensure a fast pow path is not taken
+ def __mul__(self, other):
+ raise AssertionError('__mul__ should not be called')
+
+ def __div__(self, other):
+ raise AssertionError('__div__ should not be called')
+
+ def __pow__(self, exp):
+ return SomeClass(num=self.num ** exp)
+
+ def __eq__(self, other):
+ if isinstance(other, SomeClass):
+ return self.num == other.num
+
+ __rpow__ = __pow__
+
+ def pow_for(exp, arr):
+ return np.array([x ** exp for x in arr])
+
+ obj_arr = np.array([SomeClass(1), SomeClass(2), SomeClass(3)])
+ assert_equal(obj_arr ** 0.5, pow_for(0.5, obj_arr))
+ assert_equal(obj_arr ** 0, pow_for(0, obj_arr))
+ assert_equal(obj_arr ** 1, pow_for(1, obj_arr))
+ assert_equal(obj_arr ** -1, pow_for(-1, obj_arr))
+ assert_equal(obj_arr ** 2, pow_for(2, obj_arr))
+
class TestTemporaryElide(object):
# elision is only triggered on relatively large arrays