summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis E. Oliphant <teoliphant@gmail.com>2012-02-12 20:53:37 -0600
committerTravis E. Oliphant <teoliphant@gmail.com>2012-02-13 00:15:41 -0600
commitaa2b608d2faf417bbf0fdb018d0c2b94082c0683 (patch)
treef56db0891dc56aa3423e966c4f1279336726a17f
parent015cada57ddd00d866e2a8ef396f83582c192884 (diff)
downloadnumpy-aa2b608d2faf417bbf0fdb018d0c2b94082c0683.tar.gz
BUG: Fix Ticket #2033 and fix fast_power behavior for integer arrays.
-rw-r--r--numpy/core/src/multiarray/number.c54
-rw-r--r--numpy/core/tests/test_umath.py3
2 files changed, 46 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 3dd8e0ae4..9874eb8d4 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -275,7 +275,7 @@ array_remainder(PyArrayObject *m1, PyObject *m2)
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
-static int
+static NPY_SCALARKIND
array_power_is_scalar(PyObject *o2, double* out_exponent)
{
PyObject *temp;
@@ -283,11 +283,11 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (PyInt_Check(o2)) {
*out_exponent = (double)PyInt_AsLong(o2);
- return 1;
+ return NPY_INTPOS_SCALAR;
}
if (optimize_fpexps && PyFloat_Check(o2)) {
*out_exponent = PyFloat_AsDouble(o2);
- return 1;
+ return NPY_FLOAT_SCALAR;
}
if ((PyArray_IsZeroDim(o2) &&
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
@@ -298,7 +298,20 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (temp != NULL) {
*out_exponent = PyFloat_AsDouble(o2);
Py_DECREF(temp);
- return 1;
+ if (PyArray_IsZeroDim(o2)) {
+ if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
+ return NPY_INTPOS_SCALAR;
+ }
+ else { /* ISFLOAT */
+ return NPY_FLOAT_SCALAR;
+ }
+ }
+ else if PyArray_IsScalar(o2, Integer) {
+ return NPY_INTPOS_SCALAR;
+ }
+ else { /* IsScalar(o2, Floating) */
+ return NPY_FLOAT_SCALAR;
+ }
}
}
#if (PY_VERSION_HEX >= 0x02050000)
@@ -309,18 +322,18 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (PyErr_Occurred()) {
PyErr_Clear();
}
- return 0;
+ return NPY_NOSCALAR;
}
val = PyInt_AsSsize_t(value);
if (val == -1 && PyErr_Occurred()) {
PyErr_Clear();
- return 0;
+ return NPY_NOSCALAR;
}
*out_exponent = (double) val;
- return 1;
+ return NPY_INTPOS_SCALAR;
}
#endif
- return 0;
+ return NPY_NOSCALAR;
}
/* optimize float array or complex array to a scalar power */
@@ -328,8 +341,9 @@ static PyObject *
fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
{
double exponent;
+ NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
- if (PyArray_Check(a1) && array_power_is_scalar(o2, &exponent)) {
+ if (PyArray_Check(a1) && ((kind=array_power_is_scalar(o2, &exponent))>0)) {
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
@@ -367,6 +381,11 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
return PyArray_GenericUnaryFunction(a1, fastop);
}
}
+ /* Because this is called with all arrays, we need to
+ * change the output if the kind of the scalar is different
+ * than that of the input and inplace is not on ---
+ * (thus, the input should be up-cast)
+ */
else if (exponent == 2.0) {
fastop = n_ops.multiply;
if (inplace) {
@@ -374,8 +393,21 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
(a1, (PyObject *)a1, fastop);
}
else {
- return PyArray_GenericBinaryFunction
- (a1, (PyObject *)a1, fastop);
+ PyObject *a1_conv;
+ PyArray_Descr *dtype=NULL;
+ PyObject *res;
+ /* We only special-case the FLOAT_SCALAR and integer types */
+ if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
+ dtype = PyArray_DescrFromType(NPY_DOUBLE);
+ a1 = PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1));
+ if (a1 == NULL) return NULL;
+ }
+ else {
+ Py_INCREF(a1);
+ }
+ res = PyArray_GenericBinaryFunction(a1, a1, fastop);
+ Py_DECREF(a1);
+ return res;
}
}
}
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index e6e1e593f..3e17b58dc 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -128,6 +128,9 @@ class TestPower(TestCase):
assert_complex_equal(np.power(zero, -p), cnan)
assert_complex_equal(np.power(zero, -1+0.2j), cnan)
+ def test_fast_power(self):
+ x=np.array([1,2,3], np.int16)
+ assert (x**2.00001).dtype is (x**2.0).dtype
class TestLog2(TestCase):
def test_log2_values(self) :