diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/add_newdocs.py | 30 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 276 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 163 |
3 files changed, 411 insertions, 58 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index f51860240..e749784d5 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -1517,6 +1517,10 @@ add_newdoc('numpy.core', 'einsum', Evaluates the Einstein summation convention on the operands. + An alternative way to provide the subscripts and operands is as + einsum(op0, sublist0, op1, sublist1, ..., [sublistout]). The examples + below have corresponding einsum calls with the two parameter methods. + Using the Einstein summation convention, many common multi-dimensional array operations can be represented in a simple fashion. This function provides a way compute such summations. @@ -1605,16 +1609,22 @@ add_newdoc('numpy.core', 'einsum', >>> np.einsum('ii', a) 60 + >>> np.einsum(a, [0,0]) + 60 >>> np.trace(a) 60 >>> np.einsum('ii->i', a) array([ 0, 6, 12, 18, 24]) + >>> np.einsum(a, [0,0], [0]) + array([ 0, 6, 12, 18, 24]) >>> np.diag(a) array([ 0, 6, 12, 18, 24]) >>> np.einsum('ij,j', a, b) array([ 30, 80, 130, 180, 230]) + >>> np.einsum(a, [0,1], b, [1]) + array([ 30, 80, 130, 180, 230]) >>> np.dot(a, b) array([ 30, 80, 130, 180, 230]) @@ -1622,6 +1632,10 @@ add_newdoc('numpy.core', 'einsum', array([[0, 3], [1, 4], [2, 5]]) + >>> np.einsum(c, [1,0]) + array([[0, 3], + [1, 4], + [2, 5]]) >>> c.T array([[0, 3], [1, 4], @@ -1630,24 +1644,34 @@ add_newdoc('numpy.core', 'einsum', >>> np.einsum('..., ...', 3, c) array([[ 0, 3, 6], [ 9, 12, 15]]) + >>> np.einsum(3, [Ellipsis], c, [Ellipsis]) + array([[ 0, 3, 6], + [ 9, 12, 15]]) >>> np.multiply(3, c) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.einsum('i,i', b, b) 30 + >>> np.einsum(b, [0], b, [0]) + 30 >>> np.inner(b,b) 30 >>> np.einsum('i,j', np.arange(2)+1, b) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) + >>> np.einsum(np.arange(2)+1, [0], b, [1]) + array([[0, 1, 2, 3, 4], + [0, 2, 4, 6, 8]]) >>> np.outer(np.arange(2)+1, b) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) >>> np.einsum('i...->...', a) array([50, 55, 60, 65, 70]) + >>> np.einsum(a, [0,Ellipsis], [Ellipsis]) + array([50, 55, 60, 65, 70]) >>> np.sum(a, axis=0) array([50, 55, 60, 65, 70]) @@ -1659,6 +1683,12 @@ add_newdoc('numpy.core', 'einsum', [ 4664., 5018.], [ 4796., 5162.], [ 4928., 5306.]]) + >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) + array([[ 4400., 4730.], + [ 4532., 4874.], + [ 4664., 5018.], + [ 4796., 5162.], + [ 4928., 5306.]]) >>> np.tensordot(a,b, axes=([1,0],[0,1])) array([[ 4400., 4730.], [ 4532., 4874.], diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index bbc3e8f23..9d9510a1f 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1953,47 +1953,40 @@ array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args) return _ARET(PyArray_MatrixProduct(a, v)); } -static PyObject * -array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) +static int +einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts, + PyArrayObject **op) { - char *subscripts; int i, nop; - PyArrayObject *op[NPY_MAXARGS]; - NPY_ORDER order = NPY_KEEPORDER; - NPY_CASTING casting = NPY_SAFE_CASTING; - PyArrayObject *out = NULL; - PyArray_Descr *dtype = NULL; - PyObject *ret = NULL; PyObject *subscripts_str; - PyObject *str_obj = NULL; - PyObject *str_key_obj = NULL; nop = PyTuple_GET_SIZE(args) - 1; if (nop <= 0) { PyErr_SetString(PyExc_ValueError, "must specify the einstein sum subscripts string " "and at least one operand"); - return NULL; + return -1; } - else if (nop > NPY_MAXARGS) { + else if (nop >= NPY_MAXARGS) { PyErr_SetString(PyExc_ValueError, "too many operands"); - return NULL; + return -1; } /* Get the subscripts string */ subscripts_str = PyTuple_GET_ITEM(args, 0); if (PyUnicode_Check(subscripts_str)) { - str_obj = PyUnicode_AsASCIIString(subscripts_str); - if (str_obj == NULL) { - return NULL; + *str_obj = PyUnicode_AsASCIIString(subscripts_str); + if (*str_obj == NULL) { + return -1; } - subscripts_str = str_obj; + subscripts_str = *str_obj; } - subscripts = PyBytes_AsString(subscripts_str); + *subscripts = PyBytes_AsString(subscripts_str); if (subscripts == NULL) { - Py_XDECREF(str_obj); - return NULL; + Py_XDECREF(*str_obj); + *str_obj = NULL; + return -1; } /* Set the operands to NULL */ @@ -2004,17 +1997,235 @@ array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) /* Get the operands */ for (i = 0; i < nop; ++i) { PyObject *obj = PyTuple_GET_ITEM(args, i+1); - if (PyArray_Check(obj)) { - Py_INCREF(obj); - op[i] = (PyArrayObject *)obj; + + op[i] = (PyArrayObject *)PyArray_FromAny(obj, + NULL, 0, 0, NPY_ENSUREARRAY, NULL); + if (op[i] == NULL) { + goto fail; + } + } + + return nop; + +fail: + for (i = 0; i < nop; ++i) { + Py_XDECREF(op[i]); + op[i] = NULL; + } + + return -1; +} + +/* + * Converts a list of subscripts to a string. + * + * Returns -1 on error, the number of characters placed in subscripts + * otherwise. + */ +static int +einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) +{ + int ellipsis = 0, subindex = 0; + npy_intp i, size; + PyObject *item; + + obj = PySequence_Fast(obj, "the subscripts for each operand must " + "be a list or a tuple"); + if (obj == NULL) { + return -1; + } + size = PySequence_Size(obj); + + + for (i = 0; i < size; ++i) { + item = PySequence_Fast_GET_ITEM(obj, i); + /* Ellipsis */ + if (item == Py_Ellipsis) { + if (ellipsis) { + PyErr_SetString(PyExc_ValueError, + "each subscripts list may have only one ellipsis"); + Py_DECREF(obj); + return -1; + } + if (subindex + 3 >= subsize) { + PyErr_SetString(PyExc_ValueError, + "subscripts list is too long"); + Py_DECREF(obj); + return -1; + } + subscripts[subindex++] = '.'; + subscripts[subindex++] = '.'; + subscripts[subindex++] = '.'; + ellipsis = 1; + } + /* Subscript */ + else if (PyInt_Check(item) || PyLong_Check(item)) { + long s = PyInt_AsLong(item); + if ( s < 0 || s > 2*26) { + PyErr_SetString(PyExc_ValueError, + "subscript is not within the valid range [0, 52]"); + Py_DECREF(obj); + return -1; + } + if (s < 26) { + subscripts[subindex++] = 'A' + s; + } + else { + subscripts[subindex++] = 'a' + s; + } + if (subindex >= subsize) { + PyErr_SetString(PyExc_ValueError, + "subscripts list is too long"); + Py_DECREF(obj); + return -1; + } } + /* Invalid */ else { - op[i] = (PyArrayObject *)PyArray_FromAny(obj, - NULL, 0, 0, NPY_ENSUREARRAY, NULL); - if (op[i] == NULL) { - goto finish; + PyErr_SetString(PyExc_ValueError, + "each subscript must be either an integer " + "or an ellipsis"); + Py_DECREF(obj); + return -1; + } + } + + Py_DECREF(obj); + + return subindex; +} + +/* + * Fills in the subscripts, with maximum size subsize, and op, + * with the values in the tuple 'args'. + * + * Returns -1 on error, number of operands placed in op otherwise. + */ +static int +einsum_sub_op_from_lists(PyObject *args, + char *subscripts, int subsize, PyArrayObject **op) +{ + int subindex = 0; + npy_intp i, nop; + + nop = PyTuple_Size(args)/2; + + if (nop == 0) { + PyErr_SetString(PyExc_ValueError, "must provide at least an " + "operand and a subscripts list to einsum"); + return -1; + } + else if(nop >= NPY_MAXARGS) { + PyErr_SetString(PyExc_ValueError, "too many operands"); + return -1; + } + + /* Set the operands to NULL */ + for (i = 0; i < nop; ++i) { + op[nop] = NULL; + } + + /* Get the operands and build the subscript string */ + for (i = 0; i < nop; ++i) { + PyObject *obj = PyTuple_GET_ITEM(args, 2*i); + int n; + + /* Comma between the subscripts for each operand */ + if (i != 0) { + subscripts[subindex++] = ','; + if (subindex >= subsize) { + PyErr_SetString(PyExc_ValueError, + "subscripts list is too long"); + goto fail; } } + + op[i] = (PyArrayObject *)PyArray_FromAny(obj, + NULL, 0, 0, NPY_ENSUREARRAY, NULL); + if (op[i] == NULL) { + goto fail; + } + + obj = PyTuple_GET_ITEM(args, 2*i+1); + n = einsum_list_to_subscripts(obj, subscripts+subindex, + subsize-subindex); + if (n < 0) { + goto fail; + } + subindex += n; + } + + /* Add the '->' to the string if provided */ + if (PyTuple_Size(args) == 2*nop+1) { + PyObject *obj; + int n; + + if (subindex + 2 >= subsize) { + PyErr_SetString(PyExc_ValueError, + "subscripts list is too long"); + goto fail; + } + subscripts[subindex++] = '-'; + subscripts[subindex++] = '>'; + + obj = PyTuple_GET_ITEM(args, 2*nop); + n = einsum_list_to_subscripts(obj, subscripts+subindex, + subsize-subindex); + if (n < 0) { + goto fail; + } + subindex += n; + } + + /* NULL-terminate the subscripts string */ + subscripts[subindex] = '\0'; + + return nop; + +fail: + for (i = 0; i < nop; ++i) { + Py_XDECREF(op[i]); + op[i] = NULL; + } + + return -1; +} + +static PyObject * +array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) +{ + char *subscripts = NULL, subscripts_buffer[256]; + PyObject *str_obj = NULL, *str_key_obj = NULL; + PyObject *arg0; + int i, nop; + PyArrayObject *op[NPY_MAXARGS]; + NPY_ORDER order = NPY_KEEPORDER; + NPY_CASTING casting = NPY_SAFE_CASTING; + PyArrayObject *out = NULL; + PyArray_Descr *dtype = NULL; + PyObject *ret = NULL; + + if (PyTuple_GET_SIZE(args) < 1) { + PyErr_SetString(PyExc_ValueError, + "must specify the einstein sum subscripts string " + "and at least one operand, or at least one operand " + "and its corresponding subscripts list"); + return NULL; + } + arg0 = PyTuple_GET_ITEM(args, 0); + + /* einsum('i,j', a, b), einsum('i,j->ij', a, b) */ + if (PyString_Check(arg0) || PyUnicode_Check(arg0)) { + nop = einsum_sub_op_from_str(args, &str_obj, &subscripts, op); + } + /* einsum(a, [0], b, [1]), einsum(a, [0], b, [1], [0,1]) */ + else { + nop = einsum_sub_op_from_lists(args, subscripts_buffer, + sizeof(subscripts_buffer), op); + subscripts = subscripts_buffer; + } + if (nop <= 0) { + goto finish; } /* Get the keyword arguments */ @@ -2090,6 +2301,7 @@ finish: Py_XDECREF(dtype); Py_XDECREF(str_obj); Py_XDECREF(str_key_obj); + /* out is a borrowed reference */ return ret; } @@ -2722,20 +2934,20 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds) int cmp_op; Bool rstrip; char *cmp_str; - Py_ssize_t strlen; + Py_ssize_t strlength; PyObject *res = NULL; static char msg[] = "comparision must be '==', '!=', '<', '>', '<=', '>='"; static char *kwlist[] = {"a1", "a2", "cmp", "rstrip", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOs#O&", kwlist, - &array, &other, &cmp_str, &strlen, + &array, &other, &cmp_str, &strlength, PyArray_BoolConverter, &rstrip)) { return NULL; } - if (strlen < 1 || strlen > 2) { + if (strlength < 1 || strlength > 2) { goto err; } - if (strlen > 1) { + if (strlength > 1) { if (cmp_str[1] != '=') { goto err; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 68b40ff9b..34d295a8b 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -223,10 +223,17 @@ class TestEinSum(TestCase): b = np.einsum("...", a) assert_(b.base is a) + b = np.einsum(a, [Ellipsis]) + assert_(b.base is a) + b = np.einsum("ij", a) assert_(b.base is a) assert_equal(b, a) + b = np.einsum(a, [0,1]) + assert_(b.base is a) + assert_equal(b, a) + # transpose a = np.arange(6).reshape(2,3) @@ -234,6 +241,10 @@ class TestEinSum(TestCase): assert_(b.base is a) assert_equal(b, a.T) + b = np.einsum(a, [1,0]) + assert_(b.base is a) + assert_equal(b, a.T) + # diagonal a = np.arange(9).reshape(3,3) @@ -241,6 +252,10 @@ class TestEinSum(TestCase): assert_(b.base is a) assert_equal(b, [a[i,i] for i in range(3)]) + b = np.einsum(a, [0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i] for i in range(3)]) + # diagonal with various ways of broadcasting an additional dimension a = np.arange(27).reshape(3,3,3) @@ -248,32 +263,62 @@ class TestEinSum(TestCase): assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + b = np.einsum(a, [Ellipsis,0,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] for x in a]) + b = np.einsum("ii...->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(2,0,1)]) + b = np.einsum(a, [0,0,Ellipsis], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(2,0,1)]) + b = np.einsum("...ii->i...", a) assert_(b.base is a) assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum(a, [Ellipsis,0,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum("jii->ij", a) assert_(b.base is a) assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum(a, [1,0,0], [0,1]) + assert_(b.base is a) + assert_equal(b, [a[:,i,i] for i in range(3)]) + b = np.einsum("ii...->i...", a) assert_(b.base is a) assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + b = np.einsum(a, [0,0,Ellipsis], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a.transpose(2,0,1)[:,i,i] for i in range(3)]) + b = np.einsum("i...i->i...", a) assert_(b.base is a) assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + b = np.einsum(a, [0,Ellipsis,0], [0,Ellipsis]) + assert_(b.base is a) + assert_equal(b, [a.transpose(1,0,2)[:,i,i] for i in range(3)]) + b = np.einsum("i...i->...i", a) assert_(b.base is a) assert_equal(b, [[x[i,i] for i in range(3)] for x in a.transpose(1,0,2)]) + b = np.einsum(a, [0,Ellipsis,0], [Ellipsis,0]) + assert_(b.base is a) + assert_equal(b, [[x[i,i] for i in range(3)] + for x in a.transpose(1,0,2)]) + # triple diagonal a = np.arange(27).reshape(3,3,3) @@ -281,6 +326,10 @@ class TestEinSum(TestCase): assert_(b.base is a) assert_equal(b, [a[i,i,i] for i in range(3)]) + b = np.einsum(a, [0,0,0], [0]) + assert_(b.base is a) + assert_equal(b, [a[i,i,i] for i in range(3)]) + # swap axes a = np.arange(24).reshape(2,3,4) @@ -288,52 +337,75 @@ class TestEinSum(TestCase): assert_(b.base is a) assert_equal(b, a.swapaxes(0,1)) + b = np.einsum(a, [0,1,2], [1,0,2]) + assert_(b.base is a) + assert_equal(b, a.swapaxes(0,1)) + def check_einsum_sums(self, dtype): # sum(a, axis=-1) - a = np.arange(10, dtype=dtype) - assert_equal(np.einsum("i->", a), np.sum(a, axis=-1)) + for n in range(1,17): + a = np.arange(n, dtype=dtype) + assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [0], []), + np.sum(a, axis=-1).astype(dtype)) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("...i->...", a), np.sum(a, axis=-1).astype(dtype)) + assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]), + np.sum(a, axis=-1).astype(dtype)) # sum(a, axis=0) for n in range(1,17): a = np.arange(2*n, dtype=dtype).reshape(2,n) assert_equal(np.einsum("i...->...", a), np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("i...->...", a), np.sum(a, axis=0).astype(dtype)) + assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]), + np.sum(a, axis=0).astype(dtype)) # trace(a) - a = np.arange(25, dtype=dtype).reshape(5,5) - assert_equal(np.einsum("ii", a), np.trace(a)) + for n in range(1,17): + a = np.arange(n*n, dtype=dtype).reshape(n,n) + assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype)) # multiply(a, b) for n in range(1,17): a = np.arange(3*n, dtype=dtype).reshape(3,n) b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b)) + assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]), + np.multiply(a, b)) # inner(a,b) for n in range(1,17): a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b)) + assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]), + np.inner(a, b)) for n in range(1,11): a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T) + assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]), + np.inner(a.T, b.T).T) # outer(a,b) - a = np.arange(3, dtype=dtype)+1 - b = np.arange(4, dtype=dtype)+1 - assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) + for n in range(1,17): + a = np.arange(3, dtype=dtype)+1 + b = np.arange(n, dtype=dtype)+1 + assert_equal(np.einsum("i,j", a, b), np.outer(a, b)) + assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b)) # Suppress the complex warnings for the 'as f8' tests ctx = WarningManager() @@ -346,41 +418,61 @@ class TestEinSum(TestCase): a = np.arange(4*n, dtype=dtype).reshape(4,n) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("ij, j", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b)) - for n in range(1,17): - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) c = np.arange(4, dtype=dtype) np.einsum("ij,j", a, b, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1], b, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) for n in range(1,17): a = np.arange(4*n, dtype=dtype).reshape(4,n) b = np.arange(n, dtype=dtype) assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T)) + assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T)) - a = np.arange(4*n, dtype=dtype).reshape(4,n) - b = np.arange(n, dtype=dtype) c = np.arange(4, dtype=dtype) np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.dot(b.T.astype('f8'), a.T.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a.T, [1,0], b.T, [1], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(b.T.astype('f8'), + a.T.astype('f8')).astype(dtype)) # matmat(a,b) / a.dot(b) where a is matrix, b is matrix - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(30, dtype=dtype).reshape(5,6) - assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) + for n in range(1,17): + if n < 8 or dtype != 'f2': + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b)) + assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b)) - a = np.arange(20, dtype=dtype).reshape(4,5) - b = np.arange(30, dtype=dtype).reshape(5,6) - c = np.arange(24, dtype=dtype).reshape(4,6) - np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') - assert_equal(c, - np.dot(a.astype('f8'), b.astype('f8')).astype(dtype)) + for n in range(1,17): + a = np.arange(4*n, dtype=dtype).reshape(4,n) + b = np.arange(n*6, dtype=dtype).reshape(n,6) + c = np.arange(24, dtype=dtype).reshape(4,6) + np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1], b, [1,2], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, + np.dot(a.astype('f8'), + b.astype('f8')).astype(dtype)) # matrix triple product (note this is not currently an efficient # way to multiply 3 matrices) @@ -390,15 +482,19 @@ class TestEinSum(TestCase): if dtype != 'f2': assert_equal(np.einsum("ij,jk,kl", a, b, c), a.dot(b).dot(c)) + assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]), + a.dot(b).dot(c)) - a = np.arange(12, dtype=dtype).reshape(3,4) - b = np.arange(20, dtype=dtype).reshape(4,5) - c = np.arange(30, dtype=dtype).reshape(5,6) d = np.arange(18, dtype=dtype).reshape(3,6) np.einsum("ij,jk,kl", a, b, c, out=d, dtype='f8', casting='unsafe') assert_equal(d, a.astype('f8').dot(b.astype('f8') ).dot(c.astype('f8')).astype(dtype)) + d[...] = 0 + np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d, + dtype='f8', casting='unsafe') + assert_equal(d, a.astype('f8').dot(b.astype('f8') + ).dot(c.astype('f8')).astype(dtype)) # tensordot(a, b) if np.dtype(dtype) != np.dtype('f2'): @@ -406,14 +502,19 @@ class TestEinSum(TestCase): b = np.arange(24, dtype=dtype).reshape(4,3,2) assert_equal(np.einsum("ijk, jil -> kl", a, b), np.tensordot(a,b, axes=([1,0],[0,1]))) + assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]), + np.tensordot(a,b, axes=([1,0],[0,1]))) - a = np.arange(60, dtype=dtype).reshape(3,4,5) - b = np.arange(24, dtype=dtype).reshape(4,3,2) c = np.arange(10, dtype=dtype).reshape(5,2) np.einsum("ijk,jil->kl", a, b, out=c, dtype='f8', casting='unsafe') assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), axes=([1,0],[0,1])).astype(dtype)) + c[...] = 0 + np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c, + dtype='f8', casting='unsafe') + assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'), + axes=([1,0],[0,1])).astype(dtype)) finally: ctx.__exit__() @@ -424,10 +525,15 @@ class TestEinSum(TestCase): assert_equal(np.einsum("i,i,i->i", a, b, c, dtype='?', casting='unsafe'), logical_and(logical_and(a!=0, b!=0), c!=0)) + assert_equal(np.einsum(a, [0], b, [0], c, [0], [0], + dtype='?', casting='unsafe'), + logical_and(logical_and(a!=0, b!=0), c!=0)) a = np.arange(9, dtype=dtype) assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a)) + assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a)) assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a)) + assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a)) # Various stride0, contiguous, and SSE aligned variants for n in range(1,25): @@ -451,10 +557,15 @@ class TestEinSum(TestCase): # An object array, summed as the data type a = np.arange(9, dtype=object) + b = np.einsum("i->", a, dtype=dtype, casting='unsafe') assert_equal(b, np.sum(a)) assert_equal(b.dtype, np.dtype(dtype)) + b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe') + assert_equal(b, np.sum(a)) + assert_equal(b.dtype, np.dtype(dtype)) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1'); |