diff options
author | Travis E. Oliphant <teoliphant@gmail.com> | 2012-02-12 20:53:37 -0600 |
---|---|---|
committer | Travis E. Oliphant <teoliphant@gmail.com> | 2012-02-13 00:15:41 -0600 |
commit | aa2b608d2faf417bbf0fdb018d0c2b94082c0683 (patch) | |
tree | f56db0891dc56aa3423e966c4f1279336726a17f | |
parent | 015cada57ddd00d866e2a8ef396f83582c192884 (diff) | |
download | numpy-aa2b608d2faf417bbf0fdb018d0c2b94082c0683.tar.gz |
BUG: Fix Ticket #2033 and fix fast_power behavior for integer arrays.
-rw-r--r-- | numpy/core/src/multiarray/number.c | 54 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 3 |
2 files changed, 46 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 3dd8e0ae4..9874eb8d4 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -275,7 +275,7 @@ array_remainder(PyArrayObject *m1, PyObject *m2) return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder); } -static int +static NPY_SCALARKIND array_power_is_scalar(PyObject *o2, double* out_exponent) { PyObject *temp; @@ -283,11 +283,11 @@ array_power_is_scalar(PyObject *o2, double* out_exponent) if (PyInt_Check(o2)) { *out_exponent = (double)PyInt_AsLong(o2); - return 1; + return NPY_INTPOS_SCALAR; } if (optimize_fpexps && PyFloat_Check(o2)) { *out_exponent = PyFloat_AsDouble(o2); - return 1; + return NPY_FLOAT_SCALAR; } if ((PyArray_IsZeroDim(o2) && ((PyArray_ISINTEGER((PyArrayObject *)o2) || @@ -298,7 +298,20 @@ array_power_is_scalar(PyObject *o2, double* out_exponent) if (temp != NULL) { *out_exponent = PyFloat_AsDouble(o2); Py_DECREF(temp); - return 1; + if (PyArray_IsZeroDim(o2)) { + if (PyArray_ISINTEGER((PyArrayObject *)o2)) { + return NPY_INTPOS_SCALAR; + } + else { /* ISFLOAT */ + return NPY_FLOAT_SCALAR; + } + } + else if PyArray_IsScalar(o2, Integer) { + return NPY_INTPOS_SCALAR; + } + else { /* IsScalar(o2, Floating) */ + return NPY_FLOAT_SCALAR; + } } } #if (PY_VERSION_HEX >= 0x02050000) @@ -309,18 +322,18 @@ array_power_is_scalar(PyObject *o2, double* out_exponent) if (PyErr_Occurred()) { PyErr_Clear(); } - return 0; + return NPY_NOSCALAR; } val = PyInt_AsSsize_t(value); if (val == -1 && PyErr_Occurred()) { PyErr_Clear(); - return 0; + return NPY_NOSCALAR; } *out_exponent = (double) val; - return 1; + return NPY_INTPOS_SCALAR; } #endif - return 0; + return NPY_NOSCALAR; } /* optimize float array or complex array to a scalar power */ @@ -328,8 +341,9 @@ static PyObject * fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) { double exponent; + NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */ - if (PyArray_Check(a1) && array_power_is_scalar(o2, &exponent)) { + if (PyArray_Check(a1) && ((kind=array_power_is_scalar(o2, &exponent))>0)) { PyObject *fastop = NULL; if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) { if (exponent == 1.0) { @@ -367,6 +381,11 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) return PyArray_GenericUnaryFunction(a1, fastop); } } + /* Because this is called with all arrays, we need to + * change the output if the kind of the scalar is different + * than that of the input and inplace is not on --- + * (thus, the input should be up-cast) + */ else if (exponent == 2.0) { fastop = n_ops.multiply; if (inplace) { @@ -374,8 +393,21 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace) (a1, (PyObject *)a1, fastop); } else { - return PyArray_GenericBinaryFunction - (a1, (PyObject *)a1, fastop); + PyObject *a1_conv; + PyArray_Descr *dtype=NULL; + PyObject *res; + /* We only special-case the FLOAT_SCALAR and integer types */ + if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) { + dtype = PyArray_DescrFromType(NPY_DOUBLE); + a1 = PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1)); + if (a1 == NULL) return NULL; + } + else { + Py_INCREF(a1); + } + res = PyArray_GenericBinaryFunction(a1, a1, fastop); + Py_DECREF(a1); + return res; } } } diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index e6e1e593f..3e17b58dc 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -128,6 +128,9 @@ class TestPower(TestCase): assert_complex_equal(np.power(zero, -p), cnan) assert_complex_equal(np.power(zero, -1+0.2j), cnan) + def test_fast_power(self): + x=np.array([1,2,3], np.int16) + assert (x**2.00001).dtype is (x**2.0).dtype class TestLog2(TestCase): def test_log2_values(self) : |