summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2017-04-05 12:51:57 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2017-04-06 20:50:25 +0200
commit9a2f0fec330861f8f60b805a2577eb3277ce5787 (patch)
tree2bec95e2cbed081b9d24eb984a8b591a306f4d41 /numpy
parentd21cecf4526fc463a063d058d92987633250f06c (diff)
downloadnumpy-9a2f0fec330861f8f60b805a2577eb3277ce5787.tar.gz
ENH: do integer**2. inplace
Squaring integer arrays with float argument always casts to double so the squaring itself can be done inplace. Also use square fastop instead of multiply.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/number.c20
-rw-r--r--numpy/core/tests/test_umath.py7
2 files changed, 15 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 62e8e3884..f0b5637d2 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -583,30 +583,28 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
* (thus, the input should be up-cast)
*/
else if (exponent == 2.0) {
- fastop = n_ops.multiply;
+ fastop = n_ops.square;
if (inplace) {
- return PyArray_GenericInplaceBinaryFunction
- (a1, (PyObject *)a1, fastop);
+ return PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
else {
- PyArray_Descr *dtype = NULL;
- PyObject *res;
-
/* We only special-case the FLOAT_SCALAR and integer types */
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
- dtype = PyArray_DescrFromType(NPY_DOUBLE);
+ PyObject *res;
+ PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
PyArray_ISFORTRAN(a1));
if (a1 == NULL) {
return NULL;
}
+ /* cast always creates a new array */
+ res = PyArray_GenericInplaceUnaryFunction(a1, fastop);
+ Py_DECREF(a1);
+ return res;
}
else {
- Py_INCREF(a1);
+ return PyArray_GenericUnaryFunction(a1, fastop);
}
- res = PyArray_GenericBinaryFunction(a1, (PyObject *)a1, fastop);
- Py_DECREF(a1);
- return res;
}
}
}
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 326ce2279..21ac4eda3 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -456,7 +456,12 @@ class TestPower(TestCase):
def test_fast_power(self):
x = np.array([1, 2, 3], np.int16)
- assert_((x**2.00001).dtype is (x**2.0).dtype)
+ res = x**2.0
+ assert_((x**2.00001).dtype is res.dtype)
+ assert_array_equal(res, [1, 4, 9])
+ # check the inplace operation on the casted copy doesn't mess with x
+ assert_(not np.may_share_memory(res, x))
+ assert_array_equal(x, [1, 2, 3])
# Check that the fast path ignores 1-element not 0-d arrays
res = x ** np.array([[[2]]])