summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/number.c62
-rw-r--r--numpy/core/tests/test_umath.py3
2 files changed, 53 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 3dd8e0ae4..66bc14368 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -275,19 +275,25 @@ array_remainder(PyArrayObject *m1, PyObject *m2)
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
-static int
-array_power_is_scalar(PyObject *o2, double* out_exponent)
+/* Determine if object is a scalar and if so, convert the object
+ * to a double and place it in the out_exponent argument
+ * and return the "scalar kind" as a result. If the object is
+ * not a scalar (or if there are other error conditions)
+ * return NPY_NOSCALAR, and out_exponent is undefined.
+ */
+static NPY_SCALARKIND
+is_scalar_with_conversion(PyObject *o2, double* out_exponent)
{
PyObject *temp;
const int optimize_fpexps = 1;
if (PyInt_Check(o2)) {
*out_exponent = (double)PyInt_AsLong(o2);
- return 1;
+ return NPY_INTPOS_SCALAR;
}
if (optimize_fpexps && PyFloat_Check(o2)) {
*out_exponent = PyFloat_AsDouble(o2);
- return 1;
+ return NPY_FLOAT_SCALAR;
}
if ((PyArray_IsZeroDim(o2) &&
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
@@ -298,7 +304,20 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (temp != NULL) {
*out_exponent = PyFloat_AsDouble(o2);
Py_DECREF(temp);
- return 1;
+ if (PyArray_IsZeroDim(o2)) {
+ if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
+ return NPY_INTPOS_SCALAR;
+ }
+ else { /* ISFLOAT */
+ return NPY_FLOAT_SCALAR;
+ }
+ }
+ else if PyArray_IsScalar(o2, Integer) {
+ return NPY_INTPOS_SCALAR;
+ }
+ else { /* IsScalar(o2, Floating) */
+ return NPY_FLOAT_SCALAR;
+ }
}
}
#if (PY_VERSION_HEX >= 0x02050000)
@@ -309,18 +328,18 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (PyErr_Occurred()) {
PyErr_Clear();
}
- return 0;
+ return NPY_NOSCALAR;
}
val = PyInt_AsSsize_t(value);
if (val == -1 && PyErr_Occurred()) {
PyErr_Clear();
- return 0;
+ return NPY_NOSCALAR;
}
*out_exponent = (double) val;
- return 1;
+ return NPY_INTPOS_SCALAR;
}
#endif
- return 0;
+ return NPY_NOSCALAR;
}
/* optimize float array or complex array to a scalar power */
@@ -328,8 +347,9 @@ static PyObject *
fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
{
double exponent;
+ NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
- if (PyArray_Check(a1) && array_power_is_scalar(o2, &exponent)) {
+ if (PyArray_Check(a1) && ((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
@@ -367,6 +387,11 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
return PyArray_GenericUnaryFunction(a1, fastop);
}
}
+ /* Because this is called with all arrays, we need to
+ * change the output if the kind of the scalar is different
+ * than that of the input and inplace is not on ---
+ * (thus, the input should be up-cast)
+ */
else if (exponent == 2.0) {
fastop = n_ops.multiply;
if (inplace) {
@@ -374,8 +399,21 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
(a1, (PyObject *)a1, fastop);
}
else {
- return PyArray_GenericBinaryFunction
- (a1, (PyObject *)a1, fastop);
+ PyObject *a1_conv;
+ PyArray_Descr *dtype=NULL;
+ PyObject *res;
+ /* We only special-case the FLOAT_SCALAR and integer types */
+ if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
+ dtype = PyArray_DescrFromType(NPY_DOUBLE);
+ a1 = PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1));
+ if (a1 == NULL) return NULL;
+ }
+ else {
+ Py_INCREF(a1);
+ }
+ res = PyArray_GenericBinaryFunction(a1, a1, fastop);
+ Py_DECREF(a1);
+ return res;
}
}
}
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index e6e1e593f..3e17b58dc 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -128,6 +128,9 @@ class TestPower(TestCase):
assert_complex_equal(np.power(zero, -p), cnan)
assert_complex_equal(np.power(zero, -1+0.2j), cnan)
+ def test_fast_power(self):
+ x=np.array([1,2,3], np.int16)
+ assert (x**2.00001).dtype is (x**2.0).dtype
class TestLog2(TestCase):
def test_log2_values(self) :