diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 62 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 3 |
2 files changed, 53 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 3dd8e0ae4..66bc14368 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -275,19 +275,25 @@ array_remainder(PyArrayObject *m1, PyObject *m2) return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder); } -static int -array_power_is_scalar(PyObject *o2, double* out_exponent) +/* Determine if object is a scalar and if so, convert the object + * to a double and place it in the out_exponent argument + * and return the "scalar kind" as a result. If the object is + * not a scalar (or if there are other error conditions) + * return NPY_NOSCALAR, and out_exponent is undefined. + */ +static NPY_SCALARKIND +is_scalar_with_conversion(PyObject *o2, double* out_exponent) { PyObject *temp; const int optimize_fpexps = 1; if (PyInt_Check(o2)) { *out_exponent = (double)PyInt_AsLong(o2); - return 1; + return NPY_INTPOS_SCALAR; } if (optimize_fpexps && PyFloat_Check(o2)) { *out_exponent = PyFloat_AsDouble(o2); - return 1; + return NPY_FLOAT_SCALAR; } if ((PyArray_IsZeroDim(o2) && ((PyArray_ISINTEGER((PyArrayObject *)o2) || @@ -298,7 +304,20 @@ array_power_is_scalar(PyObject *o2, double* out_exponent) if (temp != NULL) { *out_exponent = PyFloat_AsDouble(o2); Py_DECREF(temp); - return 1; + if (PyArray_IsZeroDim(o2)) { + if (PyArray_ISINTEGER((PyArrayObject *)o2)) { + return NPY_INTPOS_SCALAR; + } + else { /* ISFLOAT */ + return NPY_FLOAT_SCALAR; + } + } + else if PyArray_IsScalar(o2, Integer) { + return NPY_INTPOS_SCALAR; + } + else { /* IsScalar(o2, Floating) */ + return NPY_FLOAT_SCALAR; + } } } #if (PY_VERSION_HEX >= 0x02050000) @@ -309,18 +328,18 @@ array_power_is_scalar(PyObject *o2, double* out_exponent) if (PyErr_Occurred()) { PyErr_Clear(); } - return 0; + return NPY_NOSCALAR; } val = PyInt_AsSsize_t(value); if (val == -1 && PyErr_Occurred()) { PyErr_Clear(); - return 0; + return NPY_NOSCALAR; } *out_exponent = (double) val; - return 1; + return NPY_INTPOS_SCALAR; } #endif - return 0; + return NPY_NOSCALAR; } /* optimize float array or complex array to a scalar power */ @@ -328,8 +347,9 @@ static PyObject * fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) { double exponent; + NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */ - if (PyArray_Check(a1) && array_power_is_scalar(o2, &exponent)) { + if (PyArray_Check(a1) && ((kind=is_scalar_with_conversion(o2, &exponent))>0)) { PyObject *fastop = NULL; if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) { if (exponent == 1.0) { @@ -367,6 +387,11 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) return PyArray_GenericUnaryFunction(a1, fastop); } } + /* Because this is called with all arrays, we need to + * change the output if the kind of the scalar is different + * than that of the input and inplace is not on --- + * (thus, the input should be up-cast) + */ else if (exponent == 2.0) { fastop = n_ops.multiply; if (inplace) { @@ -374,8 +399,21 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) (a1, (PyObject *)a1, fastop); } else { - return PyArray_GenericBinaryFunction - (a1, (PyObject *)a1, fastop); + PyObject *a1_conv; + PyArray_Descr *dtype=NULL; + PyObject *res; + /* We only special-case the FLOAT_SCALAR and integer types */ + if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) { + dtype = PyArray_DescrFromType(NPY_DOUBLE); + a1 = PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1)); + if (a1 == NULL) return NULL; + } + else { + Py_INCREF(a1); + } + res = PyArray_GenericBinaryFunction(a1, a1, fastop); + Py_DECREF(a1); + return res; } } } diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index e6e1e593f..3e17b58dc 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -128,6 +128,9 @@ class TestPower(TestCase): assert_complex_equal(np.power(zero, -p), cnan) assert_complex_equal(np.power(zero, -1+0.2j), cnan) + def test_fast_power(self): + x=np.array([1,2,3], np.int16) + assert (x**2.00001).dtype is (x**2.0).dtype class TestLog2(TestCase): def test_log2_values(self) : |