summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-01-31 12:22:39 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-02-01 18:01:25 -0800
commitcdb0a56c8551182e566f0308fd9f4515d5e95d89 (patch)
tree2d19061816fdcf898d10b45bfb090beca5bf6f9e
parentabcdd9a62a1f83fa5d233477442cf0a34bde2143 (diff)
downloadnumpy-cdb0a56c8551182e566f0308fd9f4515d5e95d89.tar.gz
ENH: einsum: Add alternative einsum parameter method
This makes the following equivalent: einsum('ii', a) einsum(a, [0,0]) einsum('ii->i', a) einsum(a, [0,0], [0]) einsum('...i,...i->...', a, b) einsum(a, [Ellipsis,0], b, [Ellipsis,0], [Ellipsis])
-rw-r--r--numpy/add_newdocs.py30
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c276
-rw-r--r--numpy/core/tests/test_numeric.py163
3 files changed, 411 insertions, 58 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index f51860240..e749784d5 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -1517,6 +1517,10 @@ add_newdoc('numpy.core', 'einsum',
Evaluates the Einstein summation convention on the operands.
+ An alternative way to provide the subscripts and operands is as
+ einsum(op0, sublist0, op1, sublist1, ..., [sublistout]). The examples
+ below have corresponding einsum calls with the two parameter methods.
+
Using the Einstein summation convention, many common multi-dimensional
array operations can be represented in a simple fashion. This function
provides a way compute such summations.
@@ -1605,16 +1609,22 @@ add_newdoc('numpy.core', 'einsum',
>>> np.einsum('ii', a)
60
+ >>> np.einsum(a, [0,0])
+ 60
>>> np.trace(a)
60
>>> np.einsum('ii->i', a)
array([ 0, 6, 12, 18, 24])
+ >>> np.einsum(a, [0,0], [0])
+ array([ 0, 6, 12, 18, 24])
>>> np.diag(a)
array([ 0, 6, 12, 18, 24])
>>> np.einsum('ij,j', a, b)
array([ 30, 80, 130, 180, 230])
+ >>> np.einsum(a, [0,1], b, [1])
+ array([ 30, 80, 130, 180, 230])
>>> np.dot(a, b)
array([ 30, 80, 130, 180, 230])
@@ -1622,6 +1632,10 @@ add_newdoc('numpy.core', 'einsum',
array([[0, 3],
[1, 4],
[2, 5]])
+ >>> np.einsum(c, [1,0])
+ array([[0, 3],
+ [1, 4],
+ [2, 5]])
>>> c.T
array([[0, 3],
[1, 4],
@@ -1630,24 +1644,34 @@ add_newdoc('numpy.core', 'einsum',
>>> np.einsum('..., ...', 3, c)
array([[ 0, 3, 6],
[ 9, 12, 15]])
+ >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
+ array([[ 0, 3, 6],
+ [ 9, 12, 15]])
>>> np.multiply(3, c)
array([[ 0, 3, 6],
[ 9, 12, 15]])
>>> np.einsum('i,i', b, b)
30
+ >>> np.einsum(b, [0], b, [0])
+ 30
>>> np.inner(b,b)
30
>>> np.einsum('i,j', np.arange(2)+1, b)
array([[0, 1, 2, 3, 4],
[0, 2, 4, 6, 8]])
+ >>> np.einsum(np.arange(2)+1, [0], b, [1])
+ array([[0, 1, 2, 3, 4],
+ [0, 2, 4, 6, 8]])
>>> np.outer(np.arange(2)+1, b)
array([[0, 1, 2, 3, 4],
[0, 2, 4, 6, 8]])
>>> np.einsum('i...->...', a)
array([50, 55, 60, 65, 70])
+ >>> np.einsum(a, [0,Ellipsis], [Ellipsis])
+ array([50, 55, 60, 65, 70])
>>> np.sum(a, axis=0)
array([50, 55, 60, 65, 70])
@@ -1659,6 +1683,12 @@ add_newdoc('numpy.core', 'einsum',
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
+ >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
+ array([[ 4400., 4730.],
+ [ 4532., 4874.],
+ [ 4664., 5018.],
+ [ 4796., 5162.],
+ [ 4928., 5306.]])
>>> np.tensordot(a,b, axes=([1,0],[0,1]))
array([[ 4400., 4730.],
[ 4532., 4874.],
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index bbc3e8f23..9d9510a1f 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1953,47 +1953,40 @@ array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
return _ARET(PyArray_MatrixProduct(a, v));
}
-static PyObject *
-array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
+static int
+einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
+ PyArrayObject **op)
{
- char *subscripts;
int i, nop;
- PyArrayObject *op[NPY_MAXARGS];
- NPY_ORDER order = NPY_KEEPORDER;
- NPY_CASTING casting = NPY_SAFE_CASTING;
- PyArrayObject *out = NULL;
- PyArray_Descr *dtype = NULL;
- PyObject *ret = NULL;
PyObject *subscripts_str;
- PyObject *str_obj = NULL;
- PyObject *str_key_obj = NULL;
nop = PyTuple_GET_SIZE(args) - 1;
if (nop <= 0) {
PyErr_SetString(PyExc_ValueError,
"must specify the einstein sum subscripts string "
"and at least one operand");
- return NULL;
+ return -1;
}
- else if (nop > NPY_MAXARGS) {
+ else if (nop >= NPY_MAXARGS) {
PyErr_SetString(PyExc_ValueError, "too many operands");
- return NULL;
+ return -1;
}
/* Get the subscripts string */
subscripts_str = PyTuple_GET_ITEM(args, 0);
if (PyUnicode_Check(subscripts_str)) {
- str_obj = PyUnicode_AsASCIIString(subscripts_str);
- if (str_obj == NULL) {
- return NULL;
+ *str_obj = PyUnicode_AsASCIIString(subscripts_str);
+ if (*str_obj == NULL) {
+ return -1;
}
- subscripts_str = str_obj;
+ subscripts_str = *str_obj;
}
- subscripts = PyBytes_AsString(subscripts_str);
+ *subscripts = PyBytes_AsString(subscripts_str);
if (subscripts == NULL) {
- Py_XDECREF(str_obj);
- return NULL;
+ Py_XDECREF(*str_obj);
+ *str_obj = NULL;
+ return -1;
}
/* Set the operands to NULL */
@@ -2004,17 +1997,235 @@ array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
/* Get the operands */
for (i = 0; i < nop; ++i) {
PyObject *obj = PyTuple_GET_ITEM(args, i+1);
- if (PyArray_Check(obj)) {
- Py_INCREF(obj);
- op[i] = (PyArrayObject *)obj;
+
+ op[i] = (PyArrayObject *)PyArray_FromAny(obj,
+ NULL, 0, 0, NPY_ENSUREARRAY, NULL);
+ if (op[i] == NULL) {
+ goto fail;
+ }
+ }
+
+ return nop;
+
+fail:
+ for (i = 0; i < nop; ++i) {
+ Py_XDECREF(op[i]);
+ op[i] = NULL;
+ }
+
+ return -1;
+}
+
+/*
+ * Converts a list of subscripts to a string.
+ *
+ * Returns -1 on error, the number of characters placed in subscripts
+ * otherwise.
+ */
+static int
+einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
+{
+ int ellipsis = 0, subindex = 0;
+ npy_intp i, size;
+ PyObject *item;
+
+ obj = PySequence_Fast(obj, "the subscripts for each operand must "
+ "be a list or a tuple");
+ if (obj == NULL) {
+ return -1;
+ }
+ size = PySequence_Size(obj);
+
+
+ for (i = 0; i < size; ++i) {
+ item = PySequence_Fast_GET_ITEM(obj, i);
+ /* Ellipsis */
+ if (item == Py_Ellipsis) {
+ if (ellipsis) {
+ PyErr_SetString(PyExc_ValueError,
+ "each subscripts list may have only one ellipsis");
+ Py_DECREF(obj);
+ return -1;
+ }
+ if (subindex + 3 >= subsize) {
+ PyErr_SetString(PyExc_ValueError,
+ "subscripts list is too long");
+ Py_DECREF(obj);
+ return -1;
+ }
+ subscripts[subindex++] = '.';
+ subscripts[subindex++] = '.';
+ subscripts[subindex++] = '.';
+ ellipsis = 1;
+ }
+ /* Subscript */
+ else if (PyInt_Check(item) || PyLong_Check(item)) {
+ long s = PyInt_AsLong(item);
+ if ( s < 0 || s > 2*26) {
+ PyErr_SetString(PyExc_ValueError,
+ "subscript is not within the valid range [0, 52]");
+ Py_DECREF(obj);
+ return -1;
+ }
+ if (s < 26) {
+ subscripts[subindex++] = 'A' + s;
+ }
+ else {
+ subscripts[subindex++] = 'a' + s;
+ }
+ if (subindex >= subsize) {
+ PyErr_SetString(PyExc_ValueError,
+ "subscripts list is too long");
+ Py_DECREF(obj);
+ return -1;
+ }
}
+ /* Invalid */
else {
- op[i] = (PyArrayObject *)PyArray_FromAny(obj,
- NULL, 0, 0, NPY_ENSUREARRAY, NULL);
- if (op[i] == NULL) {
- goto finish;
+ PyErr_SetString(PyExc_ValueError,
+ "each subscript must be either an integer "
+ "or an ellipsis");
+ Py_DECREF(obj);
+ return -1;
+ }
+ }
+
+ Py_DECREF(obj);
+
+ return subindex;
+}
+
+/*
+ * Fills in the subscripts, with maximum size subsize, and op,
+ * with the values in the tuple 'args'.
+ *
+ * Returns -1 on error, number of operands placed in op otherwise.
+ */
+static int
+einsum_sub_op_from_lists(PyObject *args,
+ char *subscripts, int subsize, PyArrayObject **op)
+{
+ int subindex = 0;
+ npy_intp i, nop;
+
+ nop = PyTuple_Size(args)/2;
+
+ if (nop == 0) {
+ PyErr_SetString(PyExc_ValueError, "must provide at least an "
+ "operand and a subscripts list to einsum");
+ return -1;
+ }
+ else if(nop >= NPY_MAXARGS) {
+ PyErr_SetString(PyExc_ValueError, "too many operands");
+ return -1;
+ }
+
+ /* Set the operands to NULL */
+ for (i = 0; i < nop; ++i) {
+ op[nop] = NULL;
+ }
+
+ /* Get the operands and build the subscript string */
+ for (i = 0; i < nop; ++i) {
+ PyObject *obj = PyTuple_GET_ITEM(args, 2*i);
+ int n;
+
+ /* Comma between the subscripts for each operand */
+ if (i != 0) {
+ subscripts[subindex++] = ',';
+ if (subindex >= subsize) {
+ PyErr_SetString(PyExc_ValueError,
+ "subscripts list is too long");
+ goto fail;
}
}
+
+ op[i] = (PyArrayObject *)PyArray_FromAny(obj,
+ NULL, 0, 0, NPY_ENSUREARRAY, NULL);
+ if (op[i] == NULL) {
+ goto fail;
+ }
+
+ obj = PyTuple_GET_ITEM(args, 2*i+1);
+ n = einsum_list_to_subscripts(obj, subscripts+subindex,
+ subsize-subindex);
+ if (n < 0) {
+ goto fail;
+ }
+ subindex += n;
+ }
+
+ /* Add the '->' to the string if provided */
+ if (PyTuple_Size(args) == 2*nop+1) {
+ PyObject *obj;
+ int n;
+
+ if (subindex + 2 >= subsize) {
+ PyErr_SetString(PyExc_ValueError,
+ "subscripts list is too long");
+ goto fail;
+ }
+ subscripts[subindex++] = '-';
+ subscripts[subindex++] = '>';
+
+ obj = PyTuple_GET_ITEM(args, 2*nop);
+ n = einsum_list_to_subscripts(obj, subscripts+subindex,
+ subsize-subindex);
+ if (n < 0) {
+ goto fail;
+ }
+ subindex += n;
+ }
+
+ /* NULL-terminate the subscripts string */
+ subscripts[subindex] = '\0';
+
+ return nop;
+
+fail:
+ for (i = 0; i < nop; ++i) {
+ Py_XDECREF(op[i]);
+ op[i] = NULL;
+ }
+
+ return -1;
+}
+
+static PyObject *
+array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
+{
+ char *subscripts = NULL, subscripts_buffer[256];
+ PyObject *str_obj = NULL, *str_key_obj = NULL;
+ PyObject *arg0;
+ int i, nop;
+ PyArrayObject *op[NPY_MAXARGS];
+ NPY_ORDER order = NPY_KEEPORDER;
+ NPY_CASTING casting = NPY_SAFE_CASTING;
+ PyArrayObject *out = NULL;
+ PyArray_Descr *dtype = NULL;
+ PyObject *ret = NULL;
+
+ if (PyTuple_GET_SIZE(args) < 1) {
+ PyErr_SetString(PyExc_ValueError,
+ "must specify the einstein sum subscripts string "
+ "and at least one operand, or at least one operand "
+ "and its corresponding subscripts list");
+ return NULL;
+ }
+ arg0 = PyTuple_GET_ITEM(args, 0);
+
+ /* einsum('i,j', a, b), einsum('i,j->ij', a, b) */
+ if (PyString_Check(arg0) || PyUnicode_Check(arg0)) {
+ nop = einsum_sub_op_from_str(args, &str_obj, &subscripts, op);
+ }
+ /* einsum(a, [0], b, [1]), einsum(a, [0], b, [1], [0,1]) */
+ else {
+ nop = einsum_sub_op_from_lists(args, subscripts_buffer,
+ sizeof(subscripts_buffer), op);
+ subscripts = subscripts_buffer;
+ }
+ if (nop <= 0) {
+ goto finish;
}
/* Get the keyword arguments */
@@ -2090,6 +2301,7 @@ finish:
Py_XDECREF(dtype);
Py_XDECREF(str_obj);
Py_XDECREF(str_key_obj);
+ /* out is a borrowed reference */
return ret;
}
@@ -2722,20 +2934,20 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
int cmp_op;
Bool rstrip;
char *cmp_str;
- Py_ssize_t strlen;
+ Py_ssize_t strlength;
PyObject *res = NULL;
static char msg[] = "comparision must be '==', '!=', '<', '>', '<=', '>='";
static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist,
- &array, &other, &cmp_str, &strlen,
+ &array, &other, &cmp_str, &strlength,
PyArray_BoolConverter, &rstrip)) {
return NULL;
}
- if (strlen < 1 || strlen > 2) {
+ if (strlength < 1 || strlength > 2) {
goto err;
}
- if (strlen > 1) {
+ if (strlength > 1) {
if (cmp_str[1] != '=') {
goto err;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 68b40ff9b..34d295a8b 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -223,10 +223,17 @@ class TestEinSum(TestCase):
b = np.einsum("...", a)
assert_(b.base is a)
+ b = np.einsum(a, [Ellipsis])
+ assert_(b.base is a)
+
b = np.einsum("ij", a)
assert_(b.base is a)
assert_equal(b, a)
+ b = np.einsum(a, [0,1])
+ assert_(b.base is a)
+ assert_equal(b, a)
+
# transpose
a = np.arange(6).reshape(2,3)
@@ -234,6 +241,10 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, a.T)
+ b = np.einsum(a, [1,0])
+ assert_(b.base is a)
+ assert_equal(b, a.T)
+
# diagonal
a = np.arange(9).reshape(3,3)
@@ -241,6 +252,10 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, [a[i,i] for i in range(3)])
+ b = np.einsum(a, [0,0], [0])
+ assert_(b.base is a)
+ assert_equal(b, [a[i,i] for i in range(3)])
+
# diagonal with various ways of broadcasting an additional dimension
a = np.arange(27).reshape(3,3,3)
@@ -248,32 +263,62 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, [[x[i,i] for i in range(3)] for x in a])
+ b = np.einsum(a, [Ellipsis,0,0], [Ellipsis,0])
+ assert_(b.base is a)
+ assert_equal(b, [[x[i,i] for i in range(3)] for x in a])
+
b = np.einsum("ii...->...i", a)
assert_(b.base is a)
assert_equal(b, [[x[i,i] for i in range(3)]
for x in a.transpose(2,0,1)])
+ b = np.einsum(a, [0,0,Ellipsis], [Ellipsis,0])
+ assert_(b.base is a)
+ assert_equal(b, [[x[i,i] for i in range(3)]
+ for x in a.transpose(2,0,1)])
+
b = np.einsum("...ii->i...", a)
assert_(b.base is a)
assert_equal(b, [a[:,i,i] for i in range(3)])
+ b = np.einsum(a, [Ellipsis,0,0], [0,Ellipsis])
+ assert_(b.base is a)
+ assert_equal(b, [a[:,i,i] for i in range(3)])
+
b = np.einsum("jii->ij", a)
assert_(b.base is a)
assert_equal(b, [a[:,i,i] for i in range(3)])
+ b = np.einsum(a, [1,0,0], [0,1])
+ assert_(b.base is a)
+ assert_equal(b, [a[:,i,i] for i in range(3)])
+
b = np.einsum("ii...->i...", a)
assert_(b.base is a)
assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)])
+ b = np.einsum(a, [0,0,Ellipsis], [0,Ellipsis])
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)])
+
b = np.einsum("i...i->i...", a)
assert_(b.base is a)
assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)])
+ b = np.einsum(a, [0,Ellipsis,0], [0,Ellipsis])
+ assert_(b.base is a)
+ assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)])
+
b = np.einsum("i...i->...i", a)
assert_(b.base is a)
assert_equal(b, [[x[i,i] for i in range(3)]
for x in a.transpose(1,0,2)])
+ b = np.einsum(a, [0,Ellipsis,0], [Ellipsis,0])
+ assert_(b.base is a)
+ assert_equal(b, [[x[i,i] for i in range(3)]
+ for x in a.transpose(1,0,2)])
+
# triple diagonal
a = np.arange(27).reshape(3,3,3)
@@ -281,6 +326,10 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, [a[i,i,i] for i in range(3)])
+ b = np.einsum(a, [0,0,0], [0])
+ assert_(b.base is a)
+ assert_equal(b, [a[i,i,i] for i in range(3)])
+
# swap axes
a = np.arange(24).reshape(2,3,4)
@@ -288,52 +337,75 @@ class TestEinSum(TestCase):
assert_(b.base is a)
assert_equal(b, a.swapaxes(0,1))
+ b = np.einsum(a, [0,1,2], [1,0,2])
+ assert_(b.base is a)
+ assert_equal(b, a.swapaxes(0,1))
+
def check_einsum_sums(self, dtype):
# sum(a, axis=-1)
- a = np.arange(10, dtype=dtype)
- assert_equal(np.einsum("i->", a), np.sum(a, axis=-1))
+ for n in range(1,17):
+ a = np.arange(n, dtype=dtype)
+ assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype))
+ assert_equal(np.einsum(a, [0], []),
+ np.sum(a, axis=-1).astype(dtype))
for n in range(1,17):
a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
assert_equal(np.einsum("...i->...", a),
np.sum(a, axis=-1).astype(dtype))
+ assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]),
+ np.sum(a, axis=-1).astype(dtype))
# sum(a, axis=0)
for n in range(1,17):
a = np.arange(2*n, dtype=dtype).reshape(2,n)
assert_equal(np.einsum("i...->...", a),
np.sum(a, axis=0).astype(dtype))
+ assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]),
+ np.sum(a, axis=0).astype(dtype))
for n in range(1,17):
a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
assert_equal(np.einsum("i...->...", a),
np.sum(a, axis=0).astype(dtype))
+ assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]),
+ np.sum(a, axis=0).astype(dtype))
# trace(a)
- a = np.arange(25, dtype=dtype).reshape(5,5)
- assert_equal(np.einsum("ii", a), np.trace(a))
+ for n in range(1,17):
+ a = np.arange(n*n, dtype=dtype).reshape(n,n)
+ assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype))
+ assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype))
# multiply(a, b)
for n in range(1,17):
a = np.arange(3*n, dtype=dtype).reshape(3,n)
b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b))
+ assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]),
+ np.multiply(a, b))
# inner(a,b)
for n in range(1,17):
a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
b = np.arange(n, dtype=dtype)
assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b))
+ assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]),
+ np.inner(a, b))
for n in range(1,11):
a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2)
b = np.arange(n, dtype=dtype)
assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T)
+ assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]),
+ np.inner(a.T, b.T).T)
# outer(a,b)
- a = np.arange(3, dtype=dtype)+1
- b = np.arange(4, dtype=dtype)+1
- assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
+ for n in range(1,17):
+ a = np.arange(3, dtype=dtype)+1
+ b = np.arange(n, dtype=dtype)+1
+ assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
+ assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b))
# Suppress the complex warnings for the 'as f8' tests
ctx = WarningManager()
@@ -346,41 +418,61 @@ class TestEinSum(TestCase):
a = np.arange(4*n, dtype=dtype).reshape(4,n)
b = np.arange(n, dtype=dtype)
assert_equal(np.einsum("ij, j", a, b), np.dot(a, b))
+ assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b))
- for n in range(1,17):
- a = np.arange(4*n, dtype=dtype).reshape(4,n)
- b = np.arange(n, dtype=dtype)
c = np.arange(4, dtype=dtype)
np.einsum("ij,j", a, b, out=c,
dtype='f8', casting='unsafe')
assert_equal(c,
np.dot(a.astype('f8'),
b.astype('f8')).astype(dtype))
+ c[...] = 0
+ np.einsum(a, [0,1], b, [1], out=c,
+ dtype='f8', casting='unsafe')
+ assert_equal(c,
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
for n in range(1,17):
a = np.arange(4*n, dtype=dtype).reshape(4,n)
b = np.arange(n, dtype=dtype)
assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T))
+ assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T))
- a = np.arange(4*n, dtype=dtype).reshape(4,n)
- b = np.arange(n, dtype=dtype)
c = np.arange(4, dtype=dtype)
np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe')
assert_equal(c,
np.dot(b.T.astype('f8'),
a.T.astype('f8')).astype(dtype))
+ c[...] = 0
+ np.einsum(a.T, [1,0], b.T, [1], out=c,
+ dtype='f8', casting='unsafe')
+ assert_equal(c,
+ np.dot(b.T.astype('f8'),
+ a.T.astype('f8')).astype(dtype))
# matmat(a,b) / a.dot(b) where a is matrix, b is matrix
- a = np.arange(20, dtype=dtype).reshape(4,5)
- b = np.arange(30, dtype=dtype).reshape(5,6)
- assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
+ for n in range(1,17):
+ if n < 8 or dtype != 'f2':
+ a = np.arange(4*n, dtype=dtype).reshape(4,n)
+ b = np.arange(n*6, dtype=dtype).reshape(n,6)
+ assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
+ assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b))
- a = np.arange(20, dtype=dtype).reshape(4,5)
- b = np.arange(30, dtype=dtype).reshape(5,6)
- c = np.arange(24, dtype=dtype).reshape(4,6)
- np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe')
- assert_equal(c,
- np.dot(a.astype('f8'), b.astype('f8')).astype(dtype))
+ for n in range(1,17):
+ a = np.arange(4*n, dtype=dtype).reshape(4,n)
+ b = np.arange(n*6, dtype=dtype).reshape(n,6)
+ c = np.arange(24, dtype=dtype).reshape(4,6)
+ np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe')
+ assert_equal(c,
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
+ c[...] = 0
+ np.einsum(a, [0,1], b, [1,2], out=c,
+ dtype='f8', casting='unsafe')
+ assert_equal(c,
+ np.dot(a.astype('f8'),
+ b.astype('f8')).astype(dtype))
# matrix triple product (note this is not currently an efficient
# way to multiply 3 matrices)
@@ -390,15 +482,19 @@ class TestEinSum(TestCase):
if dtype != 'f2':
assert_equal(np.einsum("ij,jk,kl", a, b, c),
a.dot(b).dot(c))
+ assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]),
+ a.dot(b).dot(c))
- a = np.arange(12, dtype=dtype).reshape(3,4)
- b = np.arange(20, dtype=dtype).reshape(4,5)
- c = np.arange(30, dtype=dtype).reshape(5,6)
d = np.arange(18, dtype=dtype).reshape(3,6)
np.einsum("ij,jk,kl", a, b, c, out=d,
dtype='f8', casting='unsafe')
assert_equal(d, a.astype('f8').dot(b.astype('f8')
).dot(c.astype('f8')).astype(dtype))
+ d[...] = 0
+ np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d,
+ dtype='f8', casting='unsafe')
+ assert_equal(d, a.astype('f8').dot(b.astype('f8')
+ ).dot(c.astype('f8')).astype(dtype))
# tensordot(a, b)
if np.dtype(dtype) != np.dtype('f2'):
@@ -406,14 +502,19 @@ class TestEinSum(TestCase):
b = np.arange(24, dtype=dtype).reshape(4,3,2)
assert_equal(np.einsum("ijk, jil -> kl", a, b),
np.tensordot(a,b, axes=([1,0],[0,1])))
+ assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]),
+ np.tensordot(a,b, axes=([1,0],[0,1])))
- a = np.arange(60, dtype=dtype).reshape(3,4,5)
- b = np.arange(24, dtype=dtype).reshape(4,3,2)
c = np.arange(10, dtype=dtype).reshape(5,2)
np.einsum("ijk,jil->kl", a, b, out=c,
dtype='f8', casting='unsafe')
assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
axes=([1,0],[0,1])).astype(dtype))
+ c[...] = 0
+ np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c,
+ dtype='f8', casting='unsafe')
+ assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
+ axes=([1,0],[0,1])).astype(dtype))
finally:
ctx.__exit__()
@@ -424,10 +525,15 @@ class TestEinSum(TestCase):
assert_equal(np.einsum("i,i,i->i", a, b, c,
dtype='?', casting='unsafe'),
logical_and(logical_and(a!=0, b!=0), c!=0))
+ assert_equal(np.einsum(a, [0], b, [0], c, [0], [0],
+ dtype='?', casting='unsafe'),
+ logical_and(logical_and(a!=0, b!=0), c!=0))
a = np.arange(9, dtype=dtype)
assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a))
+ assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a))
assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a))
+ assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a))
# Various stride0, contiguous, and SSE aligned variants
for n in range(1,25):
@@ -451,10 +557,15 @@ class TestEinSum(TestCase):
# An object array, summed as the data type
a = np.arange(9, dtype=object)
+
b = np.einsum("i->", a, dtype=dtype, casting='unsafe')
assert_equal(b, np.sum(a))
assert_equal(b.dtype, np.dtype(dtype))
+ b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe')
+ assert_equal(b, np.sum(a))
+ assert_equal(b.dtype, np.dtype(dtype))
+
def test_einsum_sums_int8(self):
self.check_einsum_sums('i1');