summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/code_generators/numpy_api.py19
-rw-r--r--numpy/core/src/multiarray/einsum.c.src1253
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src122
-rw-r--r--numpy/core/tests/test_numeric.py119
4 files changed, 1408 insertions, 105 deletions
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index b943f17de..4620766cc 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -292,16 +292,17 @@ multiarray_funcs_api = {
'NpyIter_GetWriteFlags': 258,
'NpyIter_DebugPrint': 259,
'NpyIter_IterationNeedsAPI': 260,
+ 'NpyIter_GetInnerFixedStrideArray': 261,
#
- 'PyArray_CastingConverter': 261,
- 'PyArray_CountNonzero': 262,
- 'PyArray_PromoteTypes': 263,
- 'PyArray_MinScalarType': 264,
- 'PyArray_ResultType': 265,
- 'PyArray_CanCastArrayTo': 266,
- 'PyArray_CanCastTypeTo': 267,
- 'PyArray_EinsteinSum': 268,
- 'PyArray_FillWithZero': 269,
+ 'PyArray_CastingConverter': 262,
+ 'PyArray_CountNonzero': 263,
+ 'PyArray_PromoteTypes': 264,
+ 'PyArray_MinScalarType': 265,
+ 'PyArray_ResultType': 266,
+ 'PyArray_CanCastArrayTo': 267,
+ 'PyArray_CanCastTypeTo': 268,
+ 'PyArray_EinsteinSum': 269,
+ 'PyArray_FillWithZero': 270,
}
ufunc_types_api = {
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 4ebd3aa82..d25c6af90 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -18,6 +18,32 @@
#include <ctype.h>
+/* EINSUM_USE_SSE1 selects the float32 SSE code paths in the loops below. */
+#ifdef __SSE__
+#define EINSUM_USE_SSE1 1
+#else
+#define EINSUM_USE_SSE1 0
+#endif
+
+/*
+ * TODO: Only SSE for float32 is implemented in the loops,
+ *       no SSE2 for float64. EINSUM_USE_SSE2 is therefore defined
+ *       as 0 in both branches for now.
+ */
+#ifdef __SSE2__
+#define EINSUM_USE_SSE2 0
+#else
+#define EINSUM_USE_SSE2 0
+#endif
+
+#if EINSUM_USE_SSE1
+#include <xmmintrin.h>
+#endif
+
+#if EINSUM_USE_SSE2
+#include <emmintrin.h>
+#endif
+
+/*
+ * True when the address x has its low four bits clear (16-byte boundary).
+ * The argument is parenthesised so that the cast applies to the whole
+ * expression when a compound argument such as (ptr + 4) is passed.
+ */
+#define EINSUM_IS_SSE_ALIGNED(x) ((((npy_intp)(x))&0xf) == 0)
+
typedef enum {
BROADCAST_LEFT,
BROADCAST_RIGHT,
@@ -45,6 +71,10 @@ typedef enum {
* 0*5,
* 0*4,
* 1*3#
+ * #float32 = 0*5,
+ * 0*5,
+ * 0,1,0,0,
+ * 0*3#
*/
/**begin repeat1
@@ -55,29 +85,46 @@ static void
@name@_sum_of_products_@noplabel@(int nop, char **dataptr,
npy_intp *strides, npy_intp count)
{
+#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@)
+ char *data0 = dataptr[0];
+ npy_intp stride0 = strides[0];
+#endif
+#if (@nop@ == 2 || @nop@ == 3) && !@complex@
+ char *data1 = dataptr[1];
+ npy_intp stride1 = strides[1];
+#endif
+#if (@nop@ == 3) && !@complex@
+ char *data2 = dataptr[2];
+ npy_intp stride2 = strides[2];
+#endif
+#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@)
+ char *data_out = dataptr[@nop@];
+ npy_intp stride_out = strides[@nop@];
+#endif
+
while (count--) {
#if !@complex@
# if @nop@ == 1
- *(npy_@name@ *)dataptr[1] = @to@(@from@(*(npy_@name@ *)dataptr[0]) +
- @from@(*(npy_@name@ *)dataptr[1]));
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
+ *(npy_@name@ *)data_out = @to@(@from@(*(npy_@name@ *)data0) +
+ @from@(*(npy_@name@ *)data_out));
+ data0 += stride0;
+ data_out += stride_out;
# elif @nop@ == 2
- *(npy_@name@ *)dataptr[2] = @to@(@from@(*(npy_@name@ *)dataptr[0]) *
- @from@(*(npy_@name@ *)dataptr[1]) +
- @from@(*(npy_@name@ *)dataptr[2]));
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
- dataptr[2] += strides[2];
+ *(npy_@name@ *)data_out = @to@(@from@(*(npy_@name@ *)data0) *
+ @from@(*(npy_@name@ *)data1) +
+ @from@(*(npy_@name@ *)data_out));
+ data0 += stride0;
+ data1 += stride1;
+ data_out += stride_out;
# elif @nop@ == 3
- *(npy_@name@ *)dataptr[3] = @to@(@from@(*(npy_@name@ *)dataptr[0]) *
- @from@(*(npy_@name@ *)dataptr[1]) *
- @from@(*(npy_@name@ *)dataptr[2]) +
- @from@(*(npy_@name@ *)dataptr[3]));
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
- dataptr[2] += strides[2];
- dataptr[3] += strides[3];
+ *(npy_@name@ *)data_out = @to@(@from@(*(npy_@name@ *)data0) *
+ @from@(*(npy_@name@ *)data1) *
+ @from@(*(npy_@name@ *)data2) +
+ @from@(*(npy_@name@ *)data_out));
+ data0 += stride0;
+ data1 += stride1;
+ data2 += stride2;
+ data_out += stride_out;
# else
npy_@temp@ temp = @from@(*(npy_@name@ *)dataptr[0]);
int i;
@@ -92,12 +139,12 @@ static void
# endif
#else /* complex */
# if @nop@ == 1
- ((npy_@temp@ *)dataptr[1])[0] = ((npy_@temp@ *)dataptr[0])[0] +
- ((npy_@temp@ *)dataptr[1])[0];
- ((npy_@temp@ *)dataptr[1])[1] = ((npy_@temp@ *)dataptr[0])[1] +
- ((npy_@temp@ *)dataptr[1])[1];
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
+ ((npy_@temp@ *)data_out)[0] = ((npy_@temp@ *)data0)[0] +
+ ((npy_@temp@ *)data_out)[0];
+ ((npy_@temp@ *)data_out)[1] = ((npy_@temp@ *)data0)[1] +
+ ((npy_@temp@ *)data_out)[1];
+ data0 += stride0;
+ data_out += stride_out;
# else
# if @nop@ <= 3
#define _SUMPROD_NOP @nop@
@@ -108,7 +155,7 @@ static void
re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1];
int i;
- for (i = 1; i <= _SUMPROD_NOP; ++i) {
+ for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1];
im = re * ((npy_@temp@ *)dataptr[i])[1] +
@@ -129,6 +176,723 @@ static void
}
}
+#if @nop@ == 1
+
+static void
+@name@_sum_of_products_contig_one(int nop, char **dataptr,
+ npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+ npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+ npy_@name@ *data_out = (npy_@name@ *)dataptr[1];
+
+ /* Unroll the loop by 16 */
+ while (count >= 16) {
+ count -= 16;
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+#if !@complex@
+ data_out[@i@] = @to@(@from@(data0[@i@]) +
+ @from@(data_out[@i@]));
+#else /* complex */
+ ((npy_@temp@ *)data_out + 2*@i@)[0] =
+ ((npy_@temp@ *)data0 + 2*@i@)[0] +
+ ((npy_@temp@ *)data_out + 2*@i@)[0];
+ ((npy_@temp@ *)data_out + 2*@i@)[1] =
+ ((npy_@temp@ *)data0 + 2*@i@)[1] +
+ ((npy_@temp@ *)data_out + 2*@i@)[1];
+#endif
+ data0 += 16;
+ data_out += 16;
+/**end repeat2**/
+ }
+
+ /* Finish off the loop */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ if (count-- == 0) {
+ return;
+ }
+#if !@complex@
+ data_out[@i@] = @to@(@from@(data0[@i@]) +
+ @from@(data_out[@i@]));
+#else
+ ((npy_@temp@ *)data_out + 2*@i@)[0] =
+ ((npy_@temp@ *)data0 + 2*@i@)[0] +
+ ((npy_@temp@ *)data_out + 2*@i@)[0];
+ ((npy_@temp@ *)data_out + 2*@i@)[1] =
+ ((npy_@temp@ *)data0 + 2*@i@)[1] +
+ ((npy_@temp@ *)data_out + 2*@i@)[1];
+#endif
+/**end repeat2**/
+}
+
+#elif @nop@ == 2 && !@complex@
+
+/*
+ * Two-operand contiguous loop: out[k] = in0[k] * in1[k] + out[k] for k in
+ * [0, count).  dataptr[0] and dataptr[1] are the inputs and dataptr[2] the
+ * output; all three are indexed as packed arrays of npy_@name@, so the
+ * 'strides' argument is not consulted.
+ *
+ * When EINSUM_USE_SSE1 and the float32 template flag are both set, groups
+ * of 16 elements are processed four at a time with SSE intrinsics: first
+ * with aligned loads/stores if all three pointers pass
+ * EINSUM_IS_SSE_ALIGNED, then with unaligned loads/stores for whatever
+ * groups of 16 are still left.
+ */
+static void
+@name@_sum_of_products_contig_two(int nop, char **dataptr,
+ npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+ npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+ npy_@name@ *data1 = (npy_@name@ *)dataptr[1];
+ npy_@name@ *data_out = (npy_@name@ *)dataptr[2];
+
+#if EINSUM_USE_SSE1 && @float32@
+ __m128 a, b;
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+ /* Use aligned instructions if possible */
+ if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) &&
+ EINSUM_IS_SSE_ALIGNED(data_out)) {
+ /* Unroll the loop by 16 */
+ while (count >= 16) {
+ count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@));
+ b = _mm_add_ps(a, _mm_load_ps(data_out+@i@));
+ _mm_store_ps(data_out+@i@, b);
+/**end repeat2**/
+ data0 += 16;
+ data1 += 16;
+ data_out += 16;
+ }
+ }
+#endif
+
+ /*
+ * Unroll the loop by 16.  If the aligned loop above ran, count is
+ * already below 16 and this loop is skipped.
+ */
+ while (count >= 16) {
+ count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@));
+ b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@));
+ _mm_storeu_ps(data_out+@i@, b);
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ data_out[@i@] = @to@(@from@(data0[@i@]) *
+ @from@(data1[@i@]) +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+#endif
+ data0 += 16;
+ data1 += 16;
+ data_out += 16;
+ }
+
+ /* Finish off the loop: fewer than 16 elements remain, handled one by one */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ if (count-- == 0) {
+ return;
+ }
+ data_out[@i@] = @to@(@from@(data0[@i@]) *
+ @from@(data1[@i@]) +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+}
+
+/* Some extra specializations for the two operand case */
+/*
+ * Two-operand loop where the first operand is a single value (read once
+ * from dataptr[0]) and the second operand and the output are packed arrays:
+ * out[k] = value0 * in1[k] + out[k] for k in [0, count).
+ * The 'strides' argument is not consulted.
+ *
+ * With EINSUM_USE_SSE1 and the float32 template flag set, value0 is
+ * broadcast into an SSE register and groups of 16 elements are processed
+ * with aligned instructions when data1 and data_out allow it, otherwise
+ * with unaligned ones.
+ */
+static void
+@name@_sum_of_products_stride0_contig_outcontig_two(int nop, char **dataptr,
+ npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+ npy_@temp@ value0 = @from@(*(npy_@name@ *)dataptr[0]);
+ npy_@name@ *data1 = (npy_@name@ *)dataptr[1];
+ npy_@name@ *data_out = (npy_@name@ *)dataptr[2];
+
+#if EINSUM_USE_SSE1 && @float32@
+ __m128 a, b, value0_sse;
+
+ value0_sse = _mm_set_ps1(value0);
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+ /* Use aligned instructions if possible */
+ if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) {
+ /* Unroll the loop by 16 */
+ while (count >= 16) {
+ count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(value0_sse, _mm_load_ps(data1+@i@));
+ b = _mm_add_ps(a, _mm_load_ps(data_out+@i@));
+ _mm_store_ps(data_out+@i@, b);
+/**end repeat2**/
+ data1 += 16;
+ data_out += 16;
+ }
+ }
+#endif
+
+ /*
+ * Unroll the loop by 16.  If the aligned loop above ran, count is
+ * already below 16 and this loop is skipped.
+ */
+ while (count >= 16) {
+ count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(value0_sse, _mm_loadu_ps(data1+@i@));
+ b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@));
+ _mm_storeu_ps(data_out+@i@, b);
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ data_out[@i@] = @to@(value0 *
+ @from@(data1[@i@]) +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+#endif
+ data1 += 16;
+ data_out += 16;
+ }
+
+ /* Finish off the loop: fewer than 16 elements remain, handled one by one */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ if (count-- == 0) {
+ return;
+ }
+ data_out[@i@] = @to@(value0 *
+ @from@(data1[@i@]) +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+}
+
+/*
+ * Two-operand loop where the second operand is a single value (read once
+ * from dataptr[1]) and the first operand and the output are packed arrays:
+ * out[k] = in0[k] * value1 + out[k] for k in [0, count).
+ * The 'strides' argument is not consulted.
+ *
+ * With EINSUM_USE_SSE1 and the float32 template flag set, value1 is
+ * broadcast into an SSE register and groups of 16 elements are processed
+ * with aligned instructions when data0 and data_out allow it, otherwise
+ * with unaligned ones.
+ */
+static void
+@name@_sum_of_products_contig_stride0_outcontig_two(int nop, char **dataptr,
+ npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+ npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+ npy_@temp@ value1 = @from@(*(npy_@name@ *)dataptr[1]);
+ npy_@name@ *data_out = (npy_@name@ *)dataptr[2];
+
+#if EINSUM_USE_SSE1 && @float32@
+ __m128 a, b, value1_sse;
+
+ value1_sse = _mm_set_ps1(value1);
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+ /* Use aligned instructions if possible */
+ if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) {
+ /* Unroll the loop by 16 */
+ while (count >= 16) {
+ count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(_mm_load_ps(data0+@i@), value1_sse);
+ b = _mm_add_ps(a, _mm_load_ps(data_out+@i@));
+ _mm_store_ps(data_out+@i@, b);
+/**end repeat2**/
+ data0 += 16;
+ data_out += 16;
+ }
+ }
+#endif
+
+ /*
+ * Unroll the loop by 16.  If the aligned loop above ran, count is
+ * already below 16 and this loop is skipped.
+ */
+ while (count >= 16) {
+ count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+ a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), value1_sse);
+ b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@));
+ _mm_storeu_ps(data_out+@i@, b);
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ data_out[@i@] = @to@(@from@(data0[@i@])*
+ value1 +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+#endif
+ data0 += 16;
+ data_out += 16;
+ }
+
+ /* Finish off the loop: fewer than 16 elements remain, handled one by one */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+ if (count-- == 0) {
+ return;
+ }
+ data_out[@i@] = @to@(@from@(data0[@i@])*
+ value1 +
+ @from@(data_out[@i@]));
+/**end repeat2**/
+}
+
+/*
+ * Two-operand reduction loop: both inputs are packed arrays and the output
+ * is a single element (dataptr[2]).  sum(in0[k] * in1[k]) over
+ * k in [0, count) is accumulated locally in the working type npy_@temp@
+ * and then folded into the output element.  'strides' is not consulted.
+ *
+ * The final fold converts the existing output through @from@, adds in the
+ * working type and converts back through @to@, the same way
+ * @name@_sum_of_products_outstride0_@noplabel@ does.  Adding an already
+ * converted value with '+=' would operate on the storage representation,
+ * which is not a valid sum for types where @to@/@from@ are real
+ * conversions.
+ */
+static void
+@name@_sum_of_products_contig_contig_outstride0_two(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+    npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+    npy_@name@ *data1 = (npy_@name@ *)dataptr[1];
+    npy_@temp@ accum = 0;
+
+#if EINSUM_USE_SSE1 && @float32@
+    __m128 a, accum_sse = _mm_setzero_ps();
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Use aligned instructions if possible */
+    if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) {
+        /* Unroll the loop by 16 */
+        while (count >= 16) {
+            count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+            /*
+             * NOTE: This accumulation changes the order, so will likely
+             *       produce slightly different results.
+             */
+            a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@));
+            accum_sse = _mm_add_ps(accum_sse, a);
+/**end repeat2**/
+            data0 += 16;
+            data1 += 16;
+        }
+    }
+#endif
+
+    /* Unroll the loop by 16 */
+    while (count >= 16) {
+        count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+        /*
+         * NOTE: This accumulation changes the order, so will likely
+         *       produce slightly different results.
+         */
+        a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@));
+        accum_sse = _mm_add_ps(accum_sse, a);
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        accum += @from@(data0[@i@]) * @from@(data1[@i@]);
+/**end repeat2**/
+#endif
+        data0 += 16;
+        data1 += 16;
+    }
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Add the four SSE values and put in accum */
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    _mm_store_ss(&accum, accum_sse);
+#endif
+
+    /* Finish off the loop: fewer than 16 elements remain */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        *(npy_@name@ *)dataptr[2] = @to@(accum +
+                                    @from@(*(npy_@name@ *)dataptr[2]));
+        return;
+    }
+    accum += @from@(data0[@i@]) * @from@(data1[@i@]);
+/**end repeat2**/
+
+    *(npy_@name@ *)dataptr[2] = @to@(accum +
+                                @from@(*(npy_@name@ *)dataptr[2]));
+}
+
+/*
+ * Two-operand reduction loop where the first operand is a single value
+ * (read once from dataptr[0]), the second is a packed array and the output
+ * is a single element (dataptr[2]).  sum(in1[k]) over k in [0, count) is
+ * accumulated locally, multiplied by value0 once at the end and folded
+ * into the output element.  'strides' is not consulted.
+ *
+ * The final fold converts the existing output through @from@, adds in the
+ * working type and converts back through @to@, the same way
+ * @name@_sum_of_products_outstride0_@noplabel@ does, so the addition is
+ * never performed on the storage representation.
+ */
+static void
+@name@_sum_of_products_stride0_contig_outstride0_two(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+    npy_@temp@ value0 = @from@(*(npy_@name@ *)dataptr[0]);
+    npy_@name@ *data1 = (npy_@name@ *)dataptr[1];
+    npy_@temp@ accum = 0;
+
+#if EINSUM_USE_SSE1 && @float32@
+    __m128 a, accum_sse = _mm_setzero_ps();
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Use aligned instructions if possible */
+    if (EINSUM_IS_SSE_ALIGNED(data1)) {
+        /* Unroll the loop by 16 */
+        while (count >= 16) {
+            count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+            /*
+             * NOTE: This accumulation changes the order, so will likely
+             *       produce slightly different results.
+             */
+            accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data1+@i@));
+/**end repeat2**/
+            data1 += 16;
+        }
+    }
+#endif
+
+    /* Unroll the loop by 16 */
+    while (count >= 16) {
+        count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+        /*
+         * NOTE: This accumulation changes the order, so will likely
+         *       produce slightly different results.
+         */
+        accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data1+@i@));
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        accum += @from@(data1[@i@]);
+/**end repeat2**/
+#endif
+        data1 += 16;
+    }
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Add the four SSE values and put in accum */
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    _mm_store_ss(&accum, accum_sse);
+#endif
+
+    /* Finish off the loop: fewer than 16 elements remain */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        *(npy_@name@ *)dataptr[2] = @to@(value0 * accum +
+                                    @from@(*(npy_@name@ *)dataptr[2]));
+        return;
+    }
+    accum += @from@(data1[@i@]);
+/**end repeat2**/
+
+    *(npy_@name@ *)dataptr[2] = @to@(value0 * accum +
+                                @from@(*(npy_@name@ *)dataptr[2]));
+}
+
+/*
+ * Two-operand reduction loop where the first operand is a packed array,
+ * the second is a single value (read once from dataptr[1]) and the output
+ * is a single element (dataptr[2]).  sum(in0[k]) over k in [0, count) is
+ * accumulated locally, multiplied by value1 once at the end and folded
+ * into the output element.  'strides' is not consulted.
+ *
+ * The final fold converts the existing output through @from@, adds in the
+ * working type and converts back through @to@, the same way
+ * @name@_sum_of_products_outstride0_@noplabel@ does, so the addition is
+ * never performed on the storage representation.
+ */
+static void
+@name@_sum_of_products_contig_stride0_outstride0_two(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+    npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+    npy_@temp@ value1 = @from@(*(npy_@name@ *)dataptr[1]);
+    npy_@temp@ accum = 0;
+
+#if EINSUM_USE_SSE1 && @float32@
+    __m128 a, accum_sse = _mm_setzero_ps();
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Use aligned instructions if possible */
+    if (EINSUM_IS_SSE_ALIGNED(data0)) {
+        /* Unroll the loop by 16 */
+        while (count >= 16) {
+            count -= 16;
+
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+            /*
+             * NOTE: This accumulation changes the order, so will likely
+             *       produce slightly different results.
+             */
+            accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@));
+/**end repeat2**/
+            data0 += 16;
+        }
+    }
+#endif
+
+    /* Unroll the loop by 16 */
+    while (count >= 16) {
+        count -= 16;
+
+#if EINSUM_USE_SSE1 && @float32@
+/**begin repeat2
+ * #i = 0, 4, 8, 12#
+ */
+        /*
+         * NOTE: This accumulation changes the order, so will likely
+         *       produce slightly different results.
+         */
+        accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@));
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        accum += @from@(data0[@i@]);
+/**end repeat2**/
+#endif
+        data0 += 16;
+    }
+
+#if EINSUM_USE_SSE1 && @float32@
+    /* Add the four SSE values and put in accum */
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
+    accum_sse = _mm_add_ps(a, accum_sse);
+    _mm_store_ss(&accum, accum_sse);
+#endif
+
+    /* Finish off the loop: fewer than 16 elements remain */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        *(npy_@name@ *)dataptr[2] = @to@(accum * value1 +
+                                    @from@(*(npy_@name@ *)dataptr[2]));
+        return;
+    }
+    accum += @from@(data0[@i@]);
+/**end repeat2**/
+
+    *(npy_@name@ *)dataptr[2] = @to@(accum * value1 +
+                                @from@(*(npy_@name@ *)dataptr[2]));
+}
+
+#elif @nop@ == 3 && !@complex@
+
+/*
+ * Three-operand contiguous loop:
+ * out[k] = in0[k] * in1[k] * in2[k] + out[k] for k in [0, count).
+ * dataptr[0..2] are the inputs and dataptr[3] the output; all four are
+ * indexed as packed arrays of npy_@name@, so 'strides' is not consulted.
+ */
+static void
+@name@_sum_of_products_contig_three(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+    npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+    npy_@name@ *data1 = (npy_@name@ *)dataptr[1];
+    npy_@name@ *data2 = (npy_@name@ *)dataptr[2];
+    npy_@name@ *data_out = (npy_@name@ *)dataptr[3];
+
+    /* Unroll the loop by 16 */
+    while (count >= 16) {
+        count -= 16;
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        data_out[@i@] = @to@(@from@(data0[@i@]) *
+                             @from@(data1[@i@]) *
+                             @from@(data2[@i@]) +
+                             @from@(data_out[@i@]));
+/**end repeat2**/
+        /* Every operand pointer, including data2, moves on by 16 */
+        data0 += 16;
+        data1 += 16;
+        data2 += 16;
+        data_out += 16;
+    }
+
+    /* Finish off the loop: fewer than 16 elements remain */
+
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        return;
+    }
+    data_out[@i@] = @to@(@from@(data0[@i@]) *
+                         @from@(data1[@i@]) *
+                         @from@(data2[@i@]) +
+                         @from@(data_out[@i@]));
+/**end repeat2**/
+}
+
+#else
+
+/*
+ * Generic contiguous loop, emitted for the template cases not covered by
+ * the dedicated contig_one / contig_two / contig_three functions.  For
+ * each of 'count' elements it multiplies the current element of every
+ * input operand together and adds the product to the current output
+ * element, then advances every pointer in dataptr (inputs and output) by
+ * sizeof(npy_@name@).  'strides' is not consulted.
+ *
+ * Note that the entries of dataptr themselves are modified.
+ */
+static void
+@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+    while (count--) {
+#if !@complex@
+        npy_@temp@ temp = @from@(*(npy_@name@ *)dataptr[0]);
+        int i;
+        for (i = 1; i < nop; ++i) {
+            temp *= @from@(*(npy_@name@ *)dataptr[i]);
+        }
+        /* dataptr[nop] is the output operand */
+        *(npy_@name@ *)dataptr[nop] = @to@(temp +
+                                    @from@(*(npy_@name@ *)dataptr[nop]));
+        for (i = 0; i <= nop; ++i) {
+            dataptr[i] += sizeof(npy_@name@);
+        }
+#else /* complex */
+#    if @nop@ <= 3
+#        define _SUMPROD_NOP @nop@
+#    else
+#        define _SUMPROD_NOP nop
+#    endif
+        /* All declarations come before the first statement (C89) */
+        npy_@temp@ re, im, tmp;
+        int i;
+        re = ((npy_@temp@ *)dataptr[0])[0];
+        im = ((npy_@temp@ *)dataptr[0])[1];
+        /* Complex multiply the running product by each further operand */
+        for (i = 1; i < _SUMPROD_NOP; ++i) {
+            tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
+                  im * ((npy_@temp@ *)dataptr[i])[1];
+            im = re * ((npy_@temp@ *)dataptr[i])[1] +
+                 im * ((npy_@temp@ *)dataptr[i])[0];
+            re = tmp;
+        }
+        ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[0] = re +
+                                ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[0];
+        ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[1] = im +
+                                ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[1];
+
+        for (i = 0; i <= _SUMPROD_NOP; ++i) {
+            dataptr[i] += sizeof(npy_@name@);
+        }
+#    undef _SUMPROD_NOP
+#endif
+    }
+}
+
+#endif
+
+static void
+@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
+ npy_intp *strides, npy_intp count)
+{
+#if @complex@
+ npy_@temp@ accum_re = 0, accum_im = 0;
+#else
+ npy_@temp@ accum = 0;
+#endif
+
+#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@)
+ char *data0 = dataptr[0];
+ npy_intp stride0 = strides[0];
+#endif
+#if (@nop@ == 2 || @nop@ == 3) && !@complex@
+ char *data1 = dataptr[1];
+ npy_intp stride1 = strides[1];
+#endif
+#if (@nop@ == 3) && !@complex@
+ char *data2 = dataptr[2];
+ npy_intp stride2 = strides[2];
+#endif
+
+ while (count--) {
+#if !@complex@
+# if @nop@ == 1
+ accum += @from@(*(npy_@name@ *)data0);
+ data0 += stride0;
+# elif @nop@ == 2
+ accum += @from@(*(npy_@name@ *)data0) *
+ @from@(*(npy_@name@ *)data1);
+ data0 += stride0;
+ data1 += stride1;
+# elif @nop@ == 3
+ accum += @from@(*(npy_@name@ *)data0) *
+ @from@(*(npy_@name@ *)data1) *
+ @from@(*(npy_@name@ *)data2);
+ data0 += stride0;
+ data1 += stride1;
+ data2 += stride2;
+# else
+ npy_@temp@ temp = @from@(*(npy_@name@ *)dataptr[0]);
+ int i;
+ for (i = 1; i < nop; ++i) {
+ temp *= @from@(*(npy_@name@ *)dataptr[i]);
+ }
+ accum += temp;
+ for (i = 0; i < nop; ++i) {
+ dataptr[i] += strides[i];
+ }
+# endif
+#else /* complex */
+# if @nop@ == 1
+ accum_re += ((npy_@temp@ *)data0)[0];
+ accum_im += ((npy_@temp@ *)data0)[1];
+ data0 += stride0;
+# else
+# if @nop@ <= 3
+#define _SUMPROD_NOP @nop@
+# else
+#define _SUMPROD_NOP nop
+# endif
+ npy_@temp@ re, im, tmp;
+ re = ((npy_@temp@ *)dataptr[0])[0];
+ im = ((npy_@temp@ *)dataptr[0])[1];
+ int i;
+ for (i = 1; i < _SUMPROD_NOP; ++i) {
+ tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
+ im * ((npy_@temp@ *)dataptr[i])[1];
+ im = re * ((npy_@temp@ *)dataptr[i])[1] +
+ im * ((npy_@temp@ *)dataptr[i])[0];
+ re = tmp;
+ }
+ accum_re += re;
+ accum_im += im;
+ for (i = 0; i < _SUMPROD_NOP; ++i) {
+ dataptr[i] += strides[i];
+ }
+#undef _SUMPROD_NOP
+# endif
+#endif
+ }
+
+#if @complex@
+# if @nop@ <= 3
+ ((npy_@temp@ *)dataptr[@nop@])[0] += accum_re;
+ ((npy_@temp@ *)dataptr[@nop@])[1] += accum_im;
+# else
+ ((npy_@temp@ *)dataptr[nop])[0] += accum_re;
+ ((npy_@temp@ *)dataptr[nop])[1] += accum_im;
+# endif
+#else
+# if @nop@ <= 3
+ *((npy_@name@ *)dataptr[@nop@]) = @to@(accum +
+ @from@(*((npy_@name@ *)dataptr[@nop@])));
+# else
+ *((npy_@name@ *)dataptr[nop]) = @to@(accum +
+ @from@(*((npy_@name@ *)dataptr[nop])));
+# endif
+#endif
+
+}
+
/**end repeat1**/
/**end repeat**/
@@ -145,28 +909,45 @@ static void
bool_sum_of_products_@noplabel@(int nop, char **dataptr,
npy_intp *strides, npy_intp count)
{
+#if (@nop@ <= 3)
+ char *data0 = dataptr[0];
+ npy_intp stride0 = strides[0];
+#endif
+#if (@nop@ == 2 || @nop@ == 3)
+ char *data1 = dataptr[1];
+ npy_intp stride1 = strides[1];
+#endif
+#if (@nop@ == 3)
+ char *data2 = dataptr[2];
+ npy_intp stride2 = strides[2];
+#endif
+#if (@nop@ <= 3)
+ char *data_out = dataptr[@nop@];
+ npy_intp stride_out = strides[@nop@];
+#endif
+
while (count--) {
#if @nop@ == 1
- *(npy_bool *)dataptr[1] = *(npy_bool *)dataptr[0] ||
- *(npy_bool *)dataptr[1];
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
+ *(npy_bool *)data_out = *(npy_bool *)data0 ||
+ *(npy_bool *)data_out;
+ data0 += stride0;
+ data_out += stride_out;
#elif @nop@ == 2
- *(npy_bool *)dataptr[2] = (*(npy_bool *)dataptr[0] &&
- *(npy_bool *)dataptr[1]) ||
- *(npy_bool *)dataptr[2];
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
- dataptr[2] += strides[2];
+ *(npy_bool *)data_out = (*(npy_bool *)data0 &&
+ *(npy_bool *)data1) ||
+ *(npy_bool *)data_out;
+ data0 += stride0;
+ data1 += stride1;
+ data_out += stride_out;
#elif @nop@ == 3
- *(npy_bool *)dataptr[3] = (*(npy_bool *)dataptr[0] &&
- *(npy_bool *)dataptr[1] &&
- *(npy_bool *)dataptr[2]) ||
- *(npy_bool *)dataptr[3];
- dataptr[0] += strides[0];
- dataptr[1] += strides[1];
- dataptr[2] += strides[2];
- dataptr[3] += strides[3];
+ *(npy_bool *)data_out = (*(npy_bool *)data0 &&
+ *(npy_bool *)data1 &&
+ *(npy_bool *)data2) ||
+ *(npy_bool *)data_out;
+ data0 += stride0;
+ data1 += stride1;
+ data2 += stride2;
+ data_out += stride_out;
#else
npy_bool temp = *(npy_bool *)dataptr[0];
int i;
@@ -181,13 +962,377 @@ bool_sum_of_products_@noplabel@(int nop, char **dataptr,
}
}
+/*
+ * Contiguous boolean loop: for each of 'count' elements, the current
+ * elements of all inputs are AND-ed together and the result is OR-ed into
+ * the current output element.  All operands are walked as packed arrays of
+ * npy_bool, so 'strides' is not consulted.
+ *
+ * For one to three operands the loop is unrolled by 16 on local pointer
+ * copies and followed by a one-by-one tail; for any other operand count
+ * the entries of dataptr are advanced in place, one element at a time.
+ */
+static void
+bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr,
+                                npy_intp *NPY_UNUSED(strides), npy_intp count)
+{
+#if (@nop@ <= 3)
+    char *data0 = dataptr[0];
+#endif
+#if (@nop@ == 2 || @nop@ == 3)
+    char *data1 = dataptr[1];
+#endif
+#if (@nop@ == 3)
+    char *data2 = dataptr[2];
+#endif
+#if (@nop@ <= 3)
+    char *data_out = dataptr[@nop@];
+#endif
+
+/* Unroll the loop by 16 for fixed-size nop */
+#if (@nop@ <= 3)
+    while (count >= 16) {
+        count -= 16;
+#else
+    while (count--) {
+#endif
+
+#  if @nop@ == 1
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) ||
+                                        (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+        data0 += 16*sizeof(npy_bool);
+        data_out += 16*sizeof(npy_bool);
+#  elif @nop@ == 2
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        *((npy_bool *)data_out + @i@) =
+                        ((*((npy_bool *)data0 + @i@)) &&
+                         (*((npy_bool *)data1 + @i@))) ||
+                            (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+        data0 += 16*sizeof(npy_bool);
+        data1 += 16*sizeof(npy_bool);
+        data_out += 16*sizeof(npy_bool);
+#  elif @nop@ == 3
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+        *((npy_bool *)data_out + @i@) =
+                       ((*((npy_bool *)data0 + @i@)) &&
+                        (*((npy_bool *)data1 + @i@)) &&
+                        (*((npy_bool *)data2 + @i@))) ||
+                            (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+        data0 += 16*sizeof(npy_bool);
+        data1 += 16*sizeof(npy_bool);
+        data2 += 16*sizeof(npy_bool);
+        data_out += 16*sizeof(npy_bool);
+#  else
+        npy_bool temp = *(npy_bool *)dataptr[0];
+        int i;
+        for (i = 1; i < nop; ++i) {
+            temp = temp && *(npy_bool *)dataptr[i];
+        }
+        /* i == nop here, so dataptr[i] is the output operand */
+        *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i];
+        for (i = 0; i <= nop; ++i) {
+            dataptr[i] += sizeof(npy_bool);
+        }
+#  endif
+    }
+
+    /*
+     * If the loop was unrolled, we need to finish it off.  Fewer than 16
+     * elements remain, so one of the sixteen count checks below always
+     * returns; no pointer updates are needed after the tail.
+     */
+#if (@nop@ <= 3)
+#  if @nop@ == 1
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        return;
+    }
+    *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) ||
+                                    (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+#  elif @nop@ == 2
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        return;
+    }
+    *((npy_bool *)data_out + @i@) =
+                    ((*((npy_bool *)data0 + @i@)) &&
+                     (*((npy_bool *)data1 + @i@))) ||
+                        (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+#  elif @nop@ == 3
+/**begin repeat1
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15#
+ */
+    if (count-- == 0) {
+        return;
+    }
+    *((npy_bool *)data_out + @i@) =
+                   ((*((npy_bool *)data0 + @i@)) &&
+                    (*((npy_bool *)data1 + @i@)) &&
+                    (*((npy_bool *)data2 + @i@))) ||
+                        (*((npy_bool *)data_out + @i@));
+/**end repeat1**/
+#  endif
+#endif
+}
+
+/*
+ * Strided boolean reduction loop for an output that is a single element.
+ * For each of 'count' steps the current elements of all inputs are AND-ed
+ * together and the result is OR-ed into a local accumulator; the operands
+ * are then advanced by their entries in 'strides'.  After the loop the
+ * accumulator is OR-ed into the output element.
+ *
+ * For one to three operands local copies of the input pointers are
+ * advanced; for any other operand count the entries of dataptr are
+ * advanced in place.
+ */
+static void
+bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
+ npy_intp *strides, npy_intp count)
+{
+ npy_bool accum = 0;
+
+#if (@nop@ <= 3)
+ char *data0 = dataptr[0];
+ npy_intp stride0 = strides[0];
+#endif
+#if (@nop@ == 2 || @nop@ == 3)
+ char *data1 = dataptr[1];
+ npy_intp stride1 = strides[1];
+#endif
+#if (@nop@ == 3)
+ char *data2 = dataptr[2];
+ npy_intp stride2 = strides[2];
+#endif
+
+ while (count--) {
+#if @nop@ == 1
+ accum = *(npy_bool *)data0 || accum;
+ data0 += stride0;
+#elif @nop@ == 2
+ accum = (*(npy_bool *)data0 && *(npy_bool *)data1) || accum;
+ data0 += stride0;
+ data1 += stride1;
+#elif @nop@ == 3
+ accum = (*(npy_bool *)data0 &&
+ *(npy_bool *)data1 &&
+ *(npy_bool *)data2) || accum;
+ data0 += stride0;
+ data1 += stride1;
+ data2 += stride2;
+#else
+ npy_bool temp = *(npy_bool *)dataptr[0];
+ int i;
+ for (i = 1; i < nop; ++i) {
+ temp = temp && *(npy_bool *)dataptr[i];
+ }
+ accum = temp || accum;
+ /*
+ * NOTE(review): this loop also steps dataptr[nop] (the output) by
+ * strides[nop], unlike the numeric outstride0 loop which stops at
+ * i < nop.  Presumably harmless because this function is only
+ * selected when the output stride is 0 -- confirm against the caller.
+ */
+ for (i = 0; i <= nop; ++i) {
+ dataptr[i] += strides[i];
+ }
+#endif
+ }
+
+ /* Fold the accumulated value into the single output element */
+# if @nop@ <= 3
+ *((npy_bool *)dataptr[@nop@]) = accum || *((npy_bool *)dataptr[@nop@]);
+# else
+ *((npy_bool *)dataptr[nop]) = accum || *((npy_bool *)dataptr[nop]);
+# endif
+}
+
/**end repeat**/
typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp);
static sum_of_products_fn
-get_sum_of_products_function(int nop, int type_num)
+get_sum_of_products_function(int nop, int type_num,
+ npy_intp itemsize, npy_intp *fixed_strides)
{
+ int iop;
+
+ /* nop of 2 has more specializations */
+ if (nop == 2) {
+ if (fixed_strides[0] == itemsize) {
+ if (fixed_strides[1] == itemsize) {
+ if (fixed_strides[2] == itemsize) {
+ /* contig, contig, contig */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_contig_two;
+/**end repeat**/
+ }
+ }
+ else if (fixed_strides[2] == 0) {
+ /* contig, contig, stride0 */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_contig_contig_outstride0_two;
+/**end repeat**/
+ }
+ }
+ }
+ else if (fixed_strides[1] == 0) {
+ if (fixed_strides[2] == itemsize) {
+ /* contig, stride0, contig */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_contig_stride0_outcontig_two;
+/**end repeat**/
+ }
+ }
+ else if (fixed_strides[2] == 0) {
+ /* contig, stride0, stride0 */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_contig_stride0_outstride0_two;
+/**end repeat**/
+ }
+ }
+ }
+ }
+ else if (fixed_strides[0] == 0) {
+ if (fixed_strides[1] == itemsize) {
+ if (fixed_strides[2] == itemsize) {
+ /* stride0, contig, contig */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_stride0_contig_outcontig_two;
+/**end repeat**/
+ }
+ }
+ else if (fixed_strides[2] == 0) {
+ /* stride0, contig, stride0 */
+ switch (type_num) {
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble#
+ * #NAME = BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ return &@name@_sum_of_products_stride0_contig_outstride0_two;
+/**end repeat**/
+ }
+ }
+ }
+ }
+ }
+
+ /* Inner loop with an output stride of 0 */
+ if (fixed_strides[nop] == 0) {
+ switch (type_num) {
+/**begin repeat
+ * #name = bool,
+ * byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble,
+ * cfloat, cdouble, clongdouble#
+ * #NAME = BOOL,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ switch (nop) {
+/**begin repeat1
+ * #nop = 1, 2, 3, 1000#
+ * #noplabel = one, two, three, any#
+ */
+#if @nop@ <= 3
+ case @nop@:
+#else
+ default:
+#endif
+ return &@name@_sum_of_products_outstride0_@noplabel@;
+/**end repeat1**/
+ }
+/**end repeat**/
+ }
+ }
+
+    /*
+     * Check for all contiguous.  The contiguous loops index the output as
+     * a packed array just like the inputs, so the output stride (index
+     * nop) has to equal itemsize as well; hence nop + 1 strides are
+     * examined.
+     */
+    for (iop = 0; iop < nop + 1; ++iop) {
+        if (fixed_strides[iop] != itemsize) {
+            break;
+        }
+    }
+
+    /* Contiguous loop: taken only when no stride failed the check above */
+    if (iop == nop + 1) {
+        switch (type_num) {
+/**begin repeat
+ * #name = bool,
+ *         byte, short, int, long, longlong,
+ *         ubyte, ushort, uint, ulong, ulonglong,
+ *         half, float, double, longdouble,
+ *         cfloat, cdouble, clongdouble#
+ * #NAME = BOOL,
+ *         BYTE, SHORT, INT, LONG, LONGLONG,
+ *         UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ *         HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ *         CFLOAT, CDOUBLE, CLONGDOUBLE#
+ */
+            case NPY_@NAME@:
+                switch (nop) {
+/**begin repeat1
+ * #nop = 1, 2, 3, 1000#
+ * #noplabel = one, two, three, any#
+ */
+#if @nop@ <= 3
+                    case @nop@:
+#else
+                    default:
+#endif
+                        return &@name@_sum_of_products_contig_@noplabel@;
+/**end repeat1**/
+                }
+/**end repeat**/
+        }
+    }
+
+ /* Regular inner loop */
switch (type_num) {
/**begin repeat
* #name = bool,
@@ -209,11 +1354,10 @@ get_sum_of_products_function(int nop, int type_num)
*/
#if @nop@ <= 3
case @nop@:
- return &@name@_sum_of_products_@noplabel@;
#else
default:
- return &@name@_sum_of_products_@noplabel@;
#endif
+ return &@name@_sum_of_products_@noplabel@;
/**end repeat1**/
}
/**end repeat**/
@@ -913,6 +2057,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
NpyIter *iter;
sum_of_products_fn sop;
+ npy_intp fixed_strides[NPY_MAXARGS];
/* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
if (nop >= NPY_MAXARGS) {
@@ -1156,7 +2301,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/* Allocate the iterator */
iter = NpyIter_MultiNew(nop+1, op, NPY_ITER_NO_INNER_ITERATION|
- ((dtype == NULL) ? 0 : NPY_ITER_COMMON_DTYPE)|
+ ((dtype != NULL) ? 0 : NPY_ITER_COMMON_DTYPE)|
NPY_ITER_BUFFERED|
NPY_ITER_DELAY_BUFALLOC|
NPY_ITER_GROWINNER|
@@ -1176,8 +2321,20 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
PyArray_FillWithZero(ret);
NpyIter_Reset(iter, NULL);
+ NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
sop = get_sum_of_products_function(nop,
- NpyIter_GetDescrArray(iter)[0]->type_num);
+ NpyIter_GetDescrArray(iter)[0]->type_num,
+ NpyIter_GetDescrArray(iter)[0]->elsize,
+ fixed_strides);
+
+ #if 0
+ NpyIter_DebugPrint(iter);
+ printf("fixed strides:\n");
+ for (iop = 0; iop <= nop; ++iop) {
+ printf("%ld ", fixed_strides[iop]);
+ }
+ printf("\n");
+ #endif
/* Finally, the main loop */
if (sop == NULL) {
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index a1464ff92..653a8e829 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -113,7 +113,7 @@ typedef struct NpyIter_BD NpyIter_BufferData;
((NPY_SIZEOF_INTP)*(niter+1))
#define NIT_BASEOFFSETS_SIZEOF(itflags, ndim, niter) \
((NPY_SIZEOF_INTP)*(niter+1))
-#define NIT_OBJECTS_SIZEOF(itflags, ndim, niter) \
+#define NIT_OPERANDS_SIZEOF(itflags, ndim, niter) \
((NPY_SIZEOF_INTP)*(niter))
#define NIT_OPITFLAGS_SIZEOF(itflags, ndim, niter) \
(NPY_INTP_ALIGNED(niter))
@@ -132,12 +132,12 @@ typedef struct NpyIter_BD NpyIter_BufferData;
#define NIT_BASEOFFSETS_OFFSET(itflags, ndim, niter) \
(NIT_RESETDATAPTR_OFFSET(itflags, ndim, niter) + \
NIT_RESETDATAPTR_SIZEOF(itflags, ndim, niter))
-#define NIT_OBJECTS_OFFSET(itflags, ndim, niter) \
+#define NIT_OPERANDS_OFFSET(itflags, ndim, niter) \
(NIT_BASEOFFSETS_OFFSET(itflags, ndim, niter) + \
NIT_BASEOFFSETS_SIZEOF(itflags, ndim, niter))
#define NIT_OPITFLAGS_OFFSET(itflags, ndim, niter) \
- (NIT_OBJECTS_OFFSET(itflags, ndim, niter) + \
- NIT_OBJECTS_SIZEOF(itflags, ndim, niter))
+ (NIT_OPERANDS_OFFSET(itflags, ndim, niter) + \
+ NIT_OPERANDS_SIZEOF(itflags, ndim, niter))
#define NIT_BUFFERDATA_OFFSET(itflags, ndim, niter) \
(NIT_OPITFLAGS_OFFSET(itflags, ndim, niter) + \
NIT_OPITFLAGS_SIZEOF(itflags, ndim, niter))
@@ -168,8 +168,8 @@ typedef struct NpyIter_BD NpyIter_BufferData;
&(iter)->iter_flexdata + NIT_RESETDATAPTR_OFFSET(itflags, ndim, niter)))
#define NIT_BASEOFFSETS(iter) ((npy_intp *)( \
&(iter)->iter_flexdata + NIT_BASEOFFSETS_OFFSET(itflags, ndim, niter)))
-#define NIT_OBJECTS(iter) ((PyArrayObject **)( \
- &(iter)->iter_flexdata + NIT_OBJECTS_OFFSET(itflags, ndim, niter)))
+#define NIT_OPERANDS(iter) ((PyArrayObject **)( \
+ &(iter)->iter_flexdata + NIT_OPERANDS_OFFSET(itflags, ndim, niter)))
#define NIT_OPITFLAGS(iter) ( \
&(iter)->iter_flexdata + NIT_OPITFLAGS_OFFSET(itflags, ndim, niter))
#define NIT_BUFFERDATA(iter) ((NpyIter_BufferData *)( \
@@ -399,7 +399,7 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
NIT_ITERINDEX(iter) = 0;
memset(NIT_BASEOFFSETS(iter), 0, (niter+1)*NPY_SIZEOF_INTP);
- op = NIT_OBJECTS(iter);
+ op = NIT_OPERANDS(iter);
op_dtype = NIT_DTYPES(iter);
op_itflags = NIT_OPITFLAGS(iter);
op_dataptr = NIT_RESETDATAPTR(iter);
@@ -519,6 +519,9 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
PyArray_Descr *dtype;
int only_inputs = !(flags&NPY_ITER_COMMON_DTYPE);
+ op = NIT_OPERANDS(iter);
+ op_dtype = NIT_DTYPES(iter);
+
dtype = npyiter_get_common_dtype(niter, op,
op_itflags, op_dtype,
op_request_dtypes,
@@ -529,21 +532,21 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
return NULL;
}
if (flags&NPY_ITER_COMMON_DTYPE) {
+ NPY_IT_DBG_PRINTF("Iterator: Replacing all data types\n");
/* Replace all the data types */
for (iiter = 0; iiter < niter; ++iiter) {
Py_XDECREF(op_dtype[iiter]);
Py_INCREF(dtype);
op_dtype[iiter] = dtype;
- NIT_DTYPES(iter)[iiter] = dtype;
}
}
else {
+ NPY_IT_DBG_PRINTF("Iterator: Setting unset output data types\n");
/* Replace the NULL data types */
for (iiter = 0; iiter < niter; ++iiter) {
if (op_dtype[iiter] == NULL) {
Py_INCREF(dtype);
op_dtype[iiter] = dtype;
- NIT_DTYPES(iter)[iiter] = dtype;
}
}
}
@@ -584,7 +587,7 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
*/
itflags = NIT_ITFLAGS(iter);
ndim = NIT_NDIM(iter);
- op = NIT_OBJECTS(iter);
+ op = NIT_OPERANDS(iter);
op_dtype = NIT_DTYPES(iter);
op_itflags = NIT_OPITFLAGS(iter);
op_dataptr = NIT_RESETDATAPTR(iter);
@@ -697,7 +700,7 @@ NpyIter_Copy(NpyIter *iter)
memcpy(newiter, iter, size);
/* Take ownership of references to the operands and dtypes */
- objects = NIT_OBJECTS(newiter);
+ objects = NIT_OPERANDS(newiter);
dtypes = NIT_DTYPES(newiter);
for (iiter = 0; iiter < niter; ++iiter) {
Py_INCREF(objects[iiter]);
@@ -787,7 +790,7 @@ NpyIter_Deallocate(NpyIter *iter)
npy_intp iiter, niter = NIT_NITER(iter);
PyArray_Descr **dtype = NIT_DTYPES(iter);
- PyArrayObject **object = NIT_OBJECTS(iter);
+ PyArrayObject **object = NIT_OPERANDS(iter);
/* Deallocate any buffers and buffering data */
if (itflags&NPY_ITFLAG_BUFFER) {
@@ -2018,7 +2021,7 @@ NpyIter_GetOperandArray(NpyIter *iter)
npy_intp ndim = NIT_NDIM(iter);
npy_intp niter = NIT_NITER(iter);
- return NIT_OBJECTS(iter);
+ return NIT_OPERANDS(iter);
}
/*NUMPY_API
@@ -2052,7 +2055,7 @@ NpyIter_GetIterView(NpyIter *iter, npy_intp i)
return NULL;
}
- obj = NIT_OBJECTS(iter)[i];
+ obj = NIT_OPERANDS(iter)[i];
dtype = PyArray_DESCR(obj);
writeable = NIT_OPITFLAGS(iter)[i]&NPY_OP_ITFLAG_WRITE;
dataptr = NIT_RESETDATAPTR(iter)[i];
@@ -2161,6 +2164,65 @@ NpyIter_GetInnerStrideArray(NpyIter *iter)
}
/*NUMPY_API
+ * Fills the caller-provided out_strides (one entry per iterator operand)
+ * with the inner loop strides which stay fixed for the whole iteration.
+ * Any stride which may change during iteration receives NPY_MAX_INTP.
+ * Call this once the iterator is ready to iterate, then use the result
+ * to choose optimized inner loop functions for those fixed strides.
+ *
+ * This function may be safely called without holding the Python GIL.
+ */
+NPY_NO_EXPORT void
+NpyIter_GetInnerFixedStrideArray(NpyIter *iter, npy_intp *out_strides)
+{
+ npy_uint32 itflags = NIT_ITFLAGS(iter);
+ npy_intp ndim = NIT_NDIM(iter); /* read by name inside the NIT_* macros */
+ npy_intp iiter, niter = NIT_NITER(iter);
+
+ NpyIter_AxisData *axisdata = NIT_AXISDATA(iter);
+
+ if (itflags&NPY_ITFLAG_BUFFER) {
+ NpyIter_BufferData *data = NIT_BUFFERDATA(iter);
+ char *op_itflags = NIT_OPITFLAGS(iter);
+ npy_intp stride, *strides = NBF_STRIDES(data),
+ *ad_strides = NAD_STRIDES(axisdata);
+ PyArray_Descr **dtypes = NIT_DTYPES(iter);
+
+ for (iiter = 0; iiter < niter; ++iiter) {
+ stride = strides[iiter];
+ /* CAST / BUFNEVER operands are always / never buffered: fixed stride */
+ if (op_itflags[iiter]&
+ (NPY_OP_ITFLAG_CAST|NPY_OP_ITFLAG_BUFNEVER)) {
+ out_strides[iiter] = stride;
+ }
+ /* Zero stride while REDUCE is set: inner loop reduction, stays fixed */
+ else if (stride == 0 && (itflags&NPY_ITFLAG_REDUCE)) {
+ out_strides[iiter] = stride;
+ }
+ /*
+ * Axis-data stride equal to the element size: the operand is inner
+ * loop contiguous, so toggling buffering does not change its stride
+ */
+ else if (ad_strides[iiter] == dtypes[iiter]->elsize) {
+ out_strides[iiter] = ad_strides[iiter];
+ }
+ /*
+ * Otherwise the stride can change if the operand is sometimes
+ * buffered, sometimes not; NPY_MAX_INTP reports it as not fixed.
+ */
+ else {
+ out_strides[iiter] = NPY_MAX_INTP;
+ }
+ }
+ }
+ else {
+ /* Without buffering, the axis-data strides are copied out unchanged */
+ memcpy(out_strides, NAD_STRIDES(axisdata), niter*NPY_SIZEOF_INTP);
+ }
+}
+
+
+/*NUMPY_API
* Get a pointer to the size of the inner loop (when HasInnerLoop is false)
*
* This function may be safely called without holding the Python GIL.
@@ -2685,6 +2747,20 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op,
npy_intp iiter;
for(iiter = 0; iiter < niter; ++iiter) {
+ NPY_IT_DBG_PRINTF("Iterator: Checking casting for operand %d\n",
+ (int)iiter);
+#if NPY_IT_DBG_TRACING
+ printf("op: ");
+ if (op[iiter] != NULL) {
+ PyObject_Print((PyObject *)PyArray_DESCR(op[iiter]), stdout, 0);
+ }
+ else {
+ printf("<null>");
+ }
+ printf(", iter: ");
+ PyObject_Print((PyObject *)op_dtype[iiter], stdout, 0);
+ printf("\n");
+#endif
/* If the types aren't equivalent, a cast is necessary */
if (op[iiter] != NULL && !PyArray_EquivTypes(PyArray_DESCR(op[iiter]),
op_dtype[iiter])) {
@@ -2795,7 +2871,7 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, char *op_itflags,
char *odataptr;
NpyIter_AxisData *axisdata0, *axisdata;
npy_intp sizeof_axisdata;
- PyArrayObject **op = NIT_OBJECTS(iter);
+ PyArrayObject **op = NIT_OPERANDS(iter);
axisdata0 = NIT_AXISDATA(iter);
sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, niter);
@@ -3331,7 +3407,7 @@ npyiter_apply_forced_iteration_order(NpyIter *iter, NPY_ORDER order)
NIT_ITFLAGS(iter) |= NPY_ITFLAG_FORCEDORDER;
/* Only need to actually do something if there is more than 1 dim */
if (ndim > 1) {
- PyArrayObject **op = NIT_OBJECTS(iter);
+ PyArrayObject **op = NIT_OPERANDS(iter);
int forder = 1;
/* Check that all the array inputs are fortran order */
@@ -3975,7 +4051,7 @@ npyiter_allocate_arrays(NpyIter *iter,
npy_intp iiter, niter = NIT_NITER(iter);
NpyIter_BufferData *bufferdata = NIT_BUFFERDATA(iter);
- PyArrayObject **op = NIT_OBJECTS(iter);
+ PyArrayObject **op = NIT_OPERANDS(iter);
for (iiter = 0; iiter < niter; ++iiter) {
if (op[iiter] == NULL) {
@@ -4216,6 +4292,8 @@ npyiter_get_common_dtype(npy_intp niter, PyArrayObject **op,
PyArrayObject *arrs[NPY_MAXARGS];
PyArray_Descr *dtypes[NPY_MAXARGS];
+ NPY_IT_DBG_PRINTF("Iterator: Getting a common data type from operands\n");
+
for (iiter = 0; iiter < niter; ++iiter) {
if (op_dtype[iiter] != NULL &&
(!only_inputs || (op_itflags[iiter]&NPY_OP_ITFLAG_READ))) {
@@ -4246,7 +4324,7 @@ npyiter_allocate_transfer_functions(NpyIter *iter)
char *op_itflags = NIT_OPITFLAGS(iter);
NpyIter_BufferData *bufferdata = NIT_BUFFERDATA(iter);
NpyIter_AxisData *axisdata = NIT_AXISDATA(iter);
- PyArrayObject **op = NIT_OBJECTS(iter);
+ PyArrayObject **op = NIT_OPERANDS(iter);
PyArray_Descr **op_dtype = NIT_DTYPES(iter);
npy_intp *strides = NAD_STRIDES(axisdata), op_stride;
PyArray_StridedTransferFn **readtransferfn = NBF_READTRANSFERFN(bufferdata),
@@ -4614,7 +4692,7 @@ npyiter_copy_to_buffers(NpyIter *iter)
NpyIter_AxisData *axisdata = NIT_AXISDATA(iter);
PyArray_Descr **dtypes = NIT_DTYPES(iter);
- PyArrayObject **operands = NIT_OBJECTS(iter);
+ PyArrayObject **operands = NIT_OPERANDS(iter);
npy_intp *strides = NBF_STRIDES(bufferdata),
*ad_strides = NAD_STRIDES(axisdata);
char **ptrs = NBF_PTRS(bufferdata), **ad_ptrs = NAD_PTRS(axisdata);
@@ -4929,14 +5007,14 @@ NpyIter_DebugPrint(NpyIter *iter)
}
printf("Operands: ");
for (iiter = 0; iiter < niter; ++iiter) {
- printf("%p ", NIT_OBJECTS(iter)[iiter]);
+ printf("%p ", NIT_OPERANDS(iter)[iiter]);
}
printf("\n");
printf("Operand DTypes: ");
for (iiter = 0; iiter < niter; ++iiter) {
PyArray_Descr *dtype;
- if (NIT_OBJECTS(iter)[iiter] != NULL) {
- dtype = PyArray_DESCR(NIT_OBJECTS(iter)[iiter]);
+ if (NIT_OPERANDS(iter)[iiter] != NULL) {
+ dtype = PyArray_DESCR(NIT_OPERANDS(iter)[iiter]);
if (dtype != NULL)
PyObject_Print((PyObject *)dtype, stdout, 0);
else
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 3fe9ebd23..8bd7d45ae 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -283,72 +283,139 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, a.swapaxes(0,1))
- def test_einsum_sums(self):
+ def check_einsum_sums(self, dtype):
# sum(a, axis=-1)
- a = np.arange(10)
+ a = np.arange(10, dtype=dtype)
assert_equal(np.einsum("i->", a), np.sum(a, axis=-1))
- a = np.arange(24).reshape(2,3,4)
+ a = np.arange(24, dtype=dtype).reshape(2,3,4)
assert_equal(np.einsum("i->", a), np.sum(a, axis=-1))
# sum(a, axis=0)
- a = np.arange(10)
+ a = np.arange(10, dtype=dtype)
assert_equal(np.einsum("i...->", a), np.sum(a, axis=0))
- a = np.arange(24).reshape(2,3,4)
+ a = np.arange(24, dtype=dtype).reshape(2,3,4)
assert_equal(np.einsum("i...->", a), np.sum(a, axis=0))
# trace(a)
- a = np.arange(25).reshape(5,5)
+ a = np.arange(25, dtype=dtype).reshape(5,5)
assert_equal(np.einsum("ii", a), np.trace(a))
# multiply(a, b)
- a = np.arange(12).reshape(3,4)
- b = np.arange(24).reshape(2,3,4)
+ a = np.arange(12, dtype=dtype).reshape(3,4)
+ b = np.arange(24, dtype=dtype).reshape(2,3,4)
assert_equal(np.einsum(",", a, b), np.multiply(a, b))
# inner(a,b)
- a = np.arange(24).reshape(2,3,4)
- b = np.arange(4)
+ a = np.arange(24, dtype=dtype).reshape(2,3,4)
+ b = np.arange(4, dtype=dtype)
assert_equal(np.einsum("i,i", a, b), np.inner(a, b))
- a = np.arange(24).reshape(2,3,4)
- b = np.arange(2)
+ a = np.arange(24, dtype=dtype).reshape(2,3,4)
+ b = np.arange(2, dtype=dtype)
assert_equal(np.einsum("i...,i...", a, b), np.inner(a.T, b.T).T)
# outer(a,b)
- a = np.arange(3)+1
- b = np.arange(4)+1
+ a = np.arange(3, dtype=dtype)+1
+ b = np.arange(4, dtype=dtype)+1
assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
# matvec(a,b) / a.dot(b) where a is matrix, b is vector
- a = np.arange(20).reshape(4,5)
- b = np.arange(5)
+ a = np.arange(20, dtype=dtype).reshape(4,5)
+ b = np.arange(5, dtype=dtype)
assert_equal(np.einsum("ij,j", a, b), np.dot(a, b))
- a = np.arange(20).reshape(4,5)
- b = np.arange(5)
+ a = np.arange(20, dtype=dtype).reshape(4,5)
+ b = np.arange(5, dtype=dtype)
assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T))
# matmat(a,b) / a.dot(b) where a is matrix, b is matrix
- a = np.arange(20).reshape(4,5)
- b = np.arange(30).reshape(5,6)
+ a = np.arange(20, dtype=dtype).reshape(4,5)
+ b = np.arange(30, dtype=dtype).reshape(5,6)
assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
# tensordot(a, b)
- a = np.arange(60.).reshape(3,4,5)
- b = np.arange(24.).reshape(4,3,2)
- assert_equal(np.einsum("ijk,jil->kl", a, b),
- np.tensordot(a,b, axes=([1,0],[0,1])))
+ if np.dtype(dtype) != np.dtype('f2'):
+ a = np.arange(60, dtype=dtype).reshape(3,4,5)
+ b = np.arange(24, dtype=dtype).reshape(4,3,2)
+ assert_equal(np.einsum("ijk,jil->kl", a, b),
+ np.tensordot(a,b, axes=([1,0],[0,1])))
# logical_and(logical_and(a!=0, b!=0), c!=0)
- a = np.array([1, 3, -2, 0, 12, 13, 0, 1])
- b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12])
+ a = np.array([1, 3, -2, 0, 12, 13, 0, 1], dtype=dtype)
+ b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12], dtype=dtype)
c = np.array([True,True,False,True,True,False,True,True])
assert_equal(np.einsum("i,i,i->i", a, b, c,
dtype='?', casting='unsafe'),
logical_and(logical_and(a!=0, b!=0), c!=0))
+ a = np.arange(9, dtype=dtype)
+ assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a))
+ assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a))
+
+ # Various stride0, contiguous, and SSE aligned variants
+ a = np.arange(64, dtype=dtype)
+ if np.dtype(dtype).itemsize > 1:
+ assert_equal(np.einsum(",",a,a), np.multiply(a,a))
+ assert_equal(np.einsum("i,i", a, a), np.dot(a,a))
+ assert_equal(np.einsum("i,->i", a, 2), 2*a)
+ assert_equal(np.einsum(",i->i", 2, a), 2*a)
+ assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a))
+ assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a))
+
+ assert_equal(np.einsum(",",a[1:],a[:-1]), np.multiply(a[1:],a[:-1]))
+ assert_equal(np.einsum("i,i", a[1:], a[:-1]), np.dot(a[1:],a[:-1]))
+ assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:])
+ assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:])
+ assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:]))
+ assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:]))
+
+
+ def test_einsum_sums_int8(self):
+ self.check_einsum_sums('i1');
+
+ def test_einsum_sums_uint8(self):
+ self.check_einsum_sums('u1');
+
+ def test_einsum_sums_int16(self):
+ self.check_einsum_sums('i2');
+
+ def test_einsum_sums_uint16(self):
+ self.check_einsum_sums('u2');
+
+ def test_einsum_sums_int32(self):
+ self.check_einsum_sums('i4');
+
+ def test_einsum_sums_uint32(self):
+ self.check_einsum_sums('u4');
+
+ def test_einsum_sums_int64(self):
+ self.check_einsum_sums('i8');
+
+ def test_einsum_sums_uint64(self):
+ self.check_einsum_sums('u8');
+
+ def test_einsum_sums_float16(self):
+ self.check_einsum_sums('f2');
+
+ def test_einsum_sums_float32(self):
+ self.check_einsum_sums('f4');
+
+ def test_einsum_sums_float64(self):
+ self.check_einsum_sums('f8');
+
+ def test_einsum_sums_longdouble(self):
+ self.check_einsum_sums(np.longdouble);
+
+ def test_einsum_sums_cfloat64(self):
+ self.check_einsum_sums('c8');
+
+ def test_einsum_sums_cfloat128(self):
+ self.check_einsum_sums('c16');
+
+ def test_einsum_sums_clongdouble(self):
+ self.check_einsum_sums(np.clongdouble);
class TestNonarrayArgs(TestCase):
# check that non-array arguments to functions wrap them in arrays