diff options
author | fivemok <9394929+fivemok@users.noreply.github.com> | 2018-05-15 13:33:05 -0400 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-05-15 10:33:05 -0700 |
commit | 290aa88e50e65c2dd5df4ddb062b895e9e7baae3 (patch) | |
tree | 6447ee92779f9e852b768b473e97b0d7d86ad801 /numpy | |
parent | 79ec65de3064cbdc687820ab09f20b989028509d (diff) | |
download | numpy-290aa88e50e65c2dd5df4ddb062b895e9e7baae3.tar.gz |
BUG: remove fast scalar power for arrays with object dtype (#11050)
Fixes #11047
If objects want to optimize `x**2` as `x*x`, they can do it inside their own `__pow__`, rather than relying on numpy to do it.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 32 |
2 files changed, 35 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 915d743c8..14389a925 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -476,7 +476,9 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace, double exponent; NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */ - if (PyArray_Check(a1) && ((kind=is_scalar_with_conversion(o2, &exponent))>0)) { + if (PyArray_Check(a1) && + !PyArray_ISOBJECT(a1) && + ((kind=is_scalar_with_conversion(o2, &exponent))>0)) { PyObject *fastop = NULL; if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) { if (exponent == 1.0) { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 46bdc19a7..3ca201edd 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3317,7 +3317,39 @@ class TestBinop(object): with assert_raises(NotImplementedError): a ** 2 + def test_pow_array_object_dtype(self): + # test pow on arrays of object dtype + class SomeClass(object): + def __init__(self, num=None): + self.num = num + + # want to ensure a fast pow path is not taken + def __mul__(self, other): + raise AssertionError('__mul__ should not be called') + + def __div__(self, other): + raise AssertionError('__div__ should not be called') + + def __pow__(self, exp): + return SomeClass(num=self.num ** exp) + + def __eq__(self, other): + if isinstance(other, SomeClass): + return self.num == other.num + + __rpow__ = __pow__ + + def pow_for(exp, arr): + return np.array([x ** exp for x in arr]) + + obj_arr = np.array([SomeClass(1), SomeClass(2), SomeClass(3)]) + assert_equal(obj_arr ** 0.5, pow_for(0.5, obj_arr)) + assert_equal(obj_arr ** 0, pow_for(0, obj_arr)) + assert_equal(obj_arr ** 1, pow_for(1, obj_arr)) + assert_equal(obj_arr ** -1, pow_for(-1, obj_arr)) + assert_equal(obj_arr ** 2, pow_for(2, obj_arr)) + class TestTemporaryElide(object): # elision is only triggered on relatively large arrays |