summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/einsum.c.src259
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src2
-rw-r--r--numpy/core/tests/test_numeric.py74
3 files changed, 319 insertions, 16 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 97905ca54..4ebd3aa82 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -14,6 +14,7 @@
#define _MULTIARRAYMODULE
#include <numpy/ndarrayobject.h>
+#include <numpy/halffloat.h>
#include <ctype.h>
@@ -23,6 +24,204 @@ typedef enum {
BROADCAST_MIDDLE
} EINSUM_BROADCAST;
+/**begin repeat
+ * #name = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble,
+ * cfloat, cdouble, clongdouble#
+ * #temp = byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * float, float, double, longdouble,
+ * float, double, longdouble#
+ * #to = ,,,,,
+ * ,,,,,
+ * npy_float_to_half,,,,
+ * ,,#
+ * #from = ,,,,,
+ * ,,,,,
+ * npy_half_to_float,,,,
+ * ,,#
+ * #complex = 0*5,
+ * 0*5,
+ * 0*4,
+ * 1*3#
+ */
+
+/**begin repeat1
+ * #nop = 1, 2, 3, 1000#
+ * #noplabel = one, two, three, any#
+ */
+static void
+@name@_sum_of_products_@noplabel@(int nop, char **dataptr,
+ npy_intp *strides, npy_intp count)
+{
+ while (count--) {
+#if !@complex@
+# if @nop@ == 1
+ *(npy_@name@ *)dataptr[1] = @to@(@from@(*(npy_@name@ *)dataptr[0]) +
+ @from@(*(npy_@name@ *)dataptr[1]));
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+# elif @nop@ == 2
+ *(npy_@name@ *)dataptr[2] = @to@(@from@(*(npy_@name@ *)dataptr[0]) *
+ @from@(*(npy_@name@ *)dataptr[1]) +
+ @from@(*(npy_@name@ *)dataptr[2]));
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+ dataptr[2] += strides[2];
+# elif @nop@ == 3
+ *(npy_@name@ *)dataptr[3] = @to@(@from@(*(npy_@name@ *)dataptr[0]) *
+ @from@(*(npy_@name@ *)dataptr[1]) *
+ @from@(*(npy_@name@ *)dataptr[2]) +
+ @from@(*(npy_@name@ *)dataptr[3]));
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+ dataptr[2] += strides[2];
+ dataptr[3] += strides[3];
+# else
+ npy_@temp@ temp = @from@(*(npy_@name@ *)dataptr[0]);
+ int i;
+ for (i = 1; i < nop; ++i) {
+ temp *= @from@(*(npy_@name@ *)dataptr[i]);
+ }
+ *(npy_@name@ *)dataptr[nop] = @to@(temp +
+ @from@(*(npy_@name@ *)dataptr[i]));
+ for (i = 0; i <= nop; ++i) {
+ dataptr[i] += strides[i];
+ }
+# endif
+#else /* complex */
+# if @nop@ == 1
+ ((npy_@temp@ *)dataptr[1])[0] = ((npy_@temp@ *)dataptr[0])[0] +
+ ((npy_@temp@ *)dataptr[1])[0];
+ ((npy_@temp@ *)dataptr[1])[1] = ((npy_@temp@ *)dataptr[0])[1] +
+ ((npy_@temp@ *)dataptr[1])[1];
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+# else
+# if @nop@ <= 3
+#define _SUMPROD_NOP @nop@
+# else
+#define _SUMPROD_NOP nop
+# endif
+ npy_@temp@ re, im, tmp;
+ re = ((npy_@temp@ *)dataptr[0])[0];
+ im = ((npy_@temp@ *)dataptr[0])[1];
+ int i;
+ for (i = 1; i <= _SUMPROD_NOP; ++i) {
+ tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
+ im * ((npy_@temp@ *)dataptr[i])[1];
+ im = re * ((npy_@temp@ *)dataptr[i])[1] +
+ im * ((npy_@temp@ *)dataptr[i])[0];
+ re = tmp;
+ }
+ ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[0] = re +
+ ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[0];
+ ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[1] = im +
+ ((npy_@temp@ *)dataptr[_SUMPROD_NOP])[1];
+
+ for (i = 0; i <= _SUMPROD_NOP; ++i) {
+ dataptr[i] += strides[i];
+ }
+#undef _SUMPROD_NOP
+# endif
+#endif
+ }
+}
+
+/**end repeat1**/
+
+/**end repeat**/
+
+
+/* Do OR of ANDs for the boolean type */
+
+/**begin repeat
+ * #nop = 1, 2, 3, 1000#
+ * #noplabel = one, two, three, any#
+ */
+
+static void
+bool_sum_of_products_@noplabel@(int nop, char **dataptr,
+ npy_intp *strides, npy_intp count)
+{
+ while (count--) {
+#if @nop@ == 1
+ *(npy_bool *)dataptr[1] = *(npy_bool *)dataptr[0] ||
+ *(npy_bool *)dataptr[1];
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+#elif @nop@ == 2
+ *(npy_bool *)dataptr[2] = (*(npy_bool *)dataptr[0] &&
+ *(npy_bool *)dataptr[1]) ||
+ *(npy_bool *)dataptr[2];
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+ dataptr[2] += strides[2];
+#elif @nop@ == 3
+ *(npy_bool *)dataptr[3] = (*(npy_bool *)dataptr[0] &&
+ *(npy_bool *)dataptr[1] &&
+ *(npy_bool *)dataptr[2]) ||
+ *(npy_bool *)dataptr[3];
+ dataptr[0] += strides[0];
+ dataptr[1] += strides[1];
+ dataptr[2] += strides[2];
+ dataptr[3] += strides[3];
+#else
+ npy_bool temp = *(npy_bool *)dataptr[0];
+ int i;
+ for (i = 1; i < nop; ++i) {
+ temp = temp && *(npy_bool *)dataptr[i];
+ }
+ *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i];
+ for (i = 0; i <= nop; ++i) {
+ dataptr[i] += strides[i];
+ }
+#endif
+ }
+}
+
+/**end repeat**/
+
+typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp);
+
+static sum_of_products_fn
+get_sum_of_products_function(int nop, int type_num)
+{
+ switch (type_num) {
+/**begin repeat
+ * #name = bool,
+ * byte, short, int, long, longlong,
+ * ubyte, ushort, uint, ulong, ulonglong,
+ * half, float, double, longdouble,
+ * cfloat, cdouble, clongdouble#
+ * #NAME = BOOL,
+ * BYTE, SHORT, INT, LONG, LONGLONG,
+ * UBYTE, USHORT, UINT, ULONG, ULONGLONG,
+ * HALF, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
+ */
+ case NPY_@NAME@:
+ switch (nop) {
+/**begin repeat1
+ * #nop = 1, 2, 3, 1000#
+ * #noplabel = one, two, three, any#
+ */
+#if @nop@ <= 3
+ case @nop@:
+ return &@name@_sum_of_products_@noplabel@;
+#else
+ default:
+ return &@name@_sum_of_products_@noplabel@;
+#endif
+/**end repeat1**/
+ }
+/**end repeat**/
+ }
+
+ return NULL;
+}
+
/*
* Parses the subscripts for one operand into an output
* of 'ndim' labels
@@ -40,12 +239,9 @@ parse_operand_subscripts(char *subscripts, int length,
int i, idim, ndim_left, label;
int left_labels = 0, right_labels = 0;
- printf("Parsing operand %d subscripts\n", iop);
-
/* Process the labels from the end until the ellipsis */
idim = ndim-1;
for (i = length-1; i >= 0; --i) {
- printf("after ellipsis i = %d\n", i);
label = subscripts[i];
/* A label for an axis */
if (label > 0 && isalpha(label)) {
@@ -105,7 +301,6 @@ parse_operand_subscripts(char *subscripts, int length,
*/
if (i > 0) {
for (i = 0; i < length; ++i) {
- printf("before ellipsis i = %d\n", i);
label = subscripts[i];
/* A label for an axis */
if (label > 0 && isalnum(label)) {
@@ -153,7 +348,6 @@ parse_operand_subscripts(char *subscripts, int length,
*/
for (idim = 0; idim < ndim-1; ++idim) {
char *next;
- printf("duplicate check idim = %d\n", idim);
/* If this is a proper label, find any duplicates of it */
label = out_labels[idim];
if (label > 0) {
@@ -161,7 +355,6 @@ parse_operand_subscripts(char *subscripts, int length,
next = (char *)memchr(out_labels+idim+1, label,
ndim-idim-1);
while (next != NULL) {
- printf("inner check next=%s\n", next);
/* The offset from next to out_labels[idim] (negative) */
*next = (out_labels+idim)-next;
/* Search for the next matching label */
@@ -347,7 +540,6 @@ get_single_op_view(PyArrayObject *op, int iop, char *labels,
/* Match the labels in the operand with the output labels */
for (idim = 0; idim < ndim; ++idim) {
- printf("Matching label for dimension %d\n", idim);
label = labels[idim];
/* If this label says to merge axes, get the actual label */
if (label < 0) {
@@ -395,7 +587,6 @@ get_single_op_view(PyArrayObject *op, int iop, char *labels,
}
/* If we processed all the input axes, return a view */
if (idim == ndim) {
- printf("Returning a view\n");
Py_INCREF(PyArray_DESCR(op));
*ret = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(op),
@@ -447,7 +638,6 @@ get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
/* Copy the dimensions and strides, except when collapsing */
icombine = 0;
for (idim = 0; idim < ndim; ++idim) {
- printf("Processing dimension %d\n", idim);
label = labels[idim];
/* If this label says to merge axes, get the actual label */
if (label < 0) {
@@ -548,6 +738,7 @@ prepare_op_axes(int ndim, int iop, char *labels, npy_intp *axes,
/* Otherwise map to the broadcast axis */
else {
axes[i] = ibroadcast;
+ --ibroadcast;
}
}
/* It's a labeled dimension, find the matching one */
@@ -588,6 +779,7 @@ prepare_op_axes(int ndim, int iop, char *labels, npy_intp *axes,
/* Otherwise map to the broadcast axis */
else {
axes[i] = ibroadcast;
+ ++ibroadcast;
}
}
/* It's a labeled dimension, find the matching one */
@@ -634,6 +826,7 @@ prepare_op_axes(int ndim, int iop, char *labels, npy_intp *axes,
/* Otherwise map to the broadcast axis */
else {
axes[i] = ibroadcast;
+ ++ibroadcast;
}
}
/* It's a labeled dimension, find the matching one */
@@ -719,6 +912,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
npy_uint32 op_flags[NPY_MAXARGS];
NpyIter *iter;
+ sum_of_products_fn sop;
/* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
if (nop >= NPY_MAXARGS) {
@@ -733,7 +927,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
}
/* Parse the subscripts string into label_counts and op_labels */
- printf("Parsing input subscripts\n");
memset(label_counts, 0, sizeof(label_counts));
num_labels = 0;
for (iop = 0; iop < nop; ++iop) {
@@ -793,7 +986,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
* If there is no output signature, create one using each label
* that appeared once, in alphabetical order
*/
- printf("Parsing output subscripts\n");
if (subscripts[0] == '\0') {
char outsubscripts[NPY_MAXDIMS];
int length = 0;
@@ -850,7 +1042,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
* Process all the input ops, combining dimensions into their
* diagonal where specified.
*/
- printf("Processing inputs\n");
for (iop = 0; iop < nop; ++iop) {
char *labels = op_labels[iop];
int combine, ndim;
@@ -870,6 +1061,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
&ret)) {
return NULL;
}
+
if (ret != NULL) {
return ret;
}
@@ -968,7 +1160,8 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
NPY_ITER_BUFFERED|
NPY_ITER_DELAY_BUFALLOC|
NPY_ITER_GROWINNER|
- NPY_ITER_REDUCE_OK,
+ NPY_ITER_REDUCE_OK|
+ NPY_ITER_ZEROSIZE_OK,
order, casting,
op_flags, op_dtypes,
ndim_iter, op_axes, 0);
@@ -977,14 +1170,50 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
goto fail;
}
- /* Initialize the output to all zeros */
+ /* Initialize the output to all zeros and reset the iterator */
ret = NpyIter_GetOperandArray(iter)[nop];
Py_INCREF(ret);
PyArray_FillWithZero(ret);
NpyIter_Reset(iter, NULL);
- NpyIter_DebugPrint(iter);
+ sop = get_sum_of_products_function(nop,
+ NpyIter_GetDescrArray(iter)[0]->type_num);
+
+ /* Finally, the main loop */
+ if (sop == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid data type for einsum");
+ Py_DECREF(ret);
+ ret = NULL;
+ }
+ else if (NpyIter_GetIterSize(iter) != 0) {
+ NpyIter_IterNext_Fn iternext;
+ char **dataptr;
+ npy_intp *stride;
+ npy_intp *countptr;
+ NPY_BEGIN_THREADS_DEF;
+
+ iternext = NpyIter_GetIterNext(iter, NULL);
+ if (iternext == NULL) {
+ NpyIter_Deallocate(iter);
+ Py_DECREF(ret);
+ goto fail;
+ }
+ dataptr = NpyIter_GetDataPtrArray(iter);
+ stride = NpyIter_GetInnerStrideArray(iter);
+ countptr = NpyIter_GetInnerLoopSizePtr(iter);
+
+ NPY_BEGIN_THREADS;
+ do {
+ sop(nop, dataptr, stride, *countptr);
+ } while(iternext(iter));
+ NPY_END_THREADS;
+ }
+
NpyIter_Deallocate(iter);
+ for (iop = 0; iop < nop; ++iop) {
+ Py_DECREF(op[iop]);
+ }
return ret;
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index 30ca459ab..a1464ff92 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -18,7 +18,7 @@
#include "lowlevel_strided_loops.h"
/********** PRINTF DEBUG TRACING **************/
-#define NPY_IT_DBG_TRACING 1
+#define NPY_IT_DBG_TRACING 0
#if NPY_IT_DBG_TRACING
#define NPY_IT_DBG_PRINTF(...) printf(__VA_ARGS__)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 62f10693c..3fe9ebd23 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -276,6 +276,80 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, [a[i,i,i] for i in range(3)])
+ # swap axes
+ a = np.arange(24).reshape(2,3,4)
+
+ b = np.einsum("ijk->jik", a)
+ assert_(b.base is a)
+ assert_equal(b, a.swapaxes(0,1))
+
+ def test_einsum_sums(self):
+ # sum(a, axis=-1)
+ a = np.arange(10)
+ assert_equal(np.einsum("i->", a), np.sum(a, axis=-1))
+
+ a = np.arange(24).reshape(2,3,4)
+ assert_equal(np.einsum("i->", a), np.sum(a, axis=-1))
+
+ # sum(a, axis=0)
+ a = np.arange(10)
+ assert_equal(np.einsum("i...->", a), np.sum(a, axis=0))
+
+ a = np.arange(24).reshape(2,3,4)
+ assert_equal(np.einsum("i...->", a), np.sum(a, axis=0))
+
+ # trace(a)
+ a = np.arange(25).reshape(5,5)
+ assert_equal(np.einsum("ii", a), np.trace(a))
+
+ # multiply(a, b)
+ a = np.arange(12).reshape(3,4)
+ b = np.arange(24).reshape(2,3,4)
+ assert_equal(np.einsum(",", a, b), np.multiply(a, b))
+
+ # inner(a,b)
+ a = np.arange(24).reshape(2,3,4)
+ b = np.arange(4)
+ assert_equal(np.einsum("i,i", a, b), np.inner(a, b))
+
+ a = np.arange(24).reshape(2,3,4)
+ b = np.arange(2)
+ assert_equal(np.einsum("i...,i...", a, b), np.inner(a.T, b.T).T)
+
+ # outer(a,b)
+ a = np.arange(3)+1
+ b = np.arange(4)+1
+ assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
+
+ # matvec(a,b) / a.dot(b) where a is matrix, b is vector
+ a = np.arange(20).reshape(4,5)
+ b = np.arange(5)
+ assert_equal(np.einsum("ij,j", a, b), np.dot(a, b))
+
+ a = np.arange(20).reshape(4,5)
+ b = np.arange(5)
+ assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T))
+
+ # matmat(a,b) / a.dot(b) where a is matrix, b is matrix
+ a = np.arange(20).reshape(4,5)
+ b = np.arange(30).reshape(5,6)
+ assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
+
+ # tensordot(a, b)
+ a = np.arange(60.).reshape(3,4,5)
+ b = np.arange(24.).reshape(4,3,2)
+ assert_equal(np.einsum("ijk,jil->kl", a, b),
+ np.tensordot(a,b, axes=([1,0],[0,1])))
+
+ # logical_and(logical_and(a!=0, b!=0), c!=0)
+ a = np.array([1, 3, -2, 0, 12, 13, 0, 1])
+ b = np.array([0, 3.5, 0., -2, 0, 1, 3, 12])
+ c = np.array([True,True,False,True,True,False,True,True])
+ assert_equal(np.einsum("i,i,i->i", a, b, c,
+ dtype='?', casting='unsafe'),
+ logical_and(logical_and(a!=0, b!=0), c!=0))
+
+
class TestNonarrayArgs(TestCase):
# check that non-array arguments to functions wrap them in arrays
def test_squeeze(self):