summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-04-28 11:50:19 +0200
committerGitHub <noreply@github.com>2023-04-28 11:50:19 +0200
commit94e572399754c48af364a669cf28422ded5f6bec (patch)
tree03ee8b130e0f7b430d6b3cdfe915297d26be849f
parentde702b37336781b504d5b13d720cb4ebbbbc97db (diff)
parent962120be1b0e5f1a4015292d763de1c109aaf05c (diff)
downloadnumpy-94e572399754c48af364a669cf28422ded5f6bec.tar.gz
Merge pull request #18053 from Iamsoto/adding_object_to_einsum
ENH: Adding Object dtype to einsum
-rw-r--r--doc/release/upcoming_changes/18053.new_feature.rst4
-rw-r--r--numpy/core/src/multiarray/convert.c33
-rw-r--r--numpy/core/src/multiarray/einsum.c.src39
-rw-r--r--numpy/core/src/multiarray/einsum_sumprod.c.src57
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/tests/test_einsum.py141
-rw-r--r--numpy/core/tests/test_regression.py5
7 files changed, 235 insertions, 49 deletions
diff --git a/doc/release/upcoming_changes/18053.new_feature.rst b/doc/release/upcoming_changes/18053.new_feature.rst
new file mode 100644
index 000000000..fea04f79a
--- /dev/null
+++ b/doc/release/upcoming_changes/18053.new_feature.rst
@@ -0,0 +1,4 @@
+``np.einsum`` now accepts arrays with ``object`` dtype
+------------------------------------------------------
+The code path will call python operators on object dtype arrays, much
+like ``np.dot`` and ``np.matmul``.
diff --git a/numpy/core/src/multiarray/convert.c b/numpy/core/src/multiarray/convert.c
index e8b880a43..7ef80cf28 100644
--- a/numpy/core/src/multiarray/convert.c
+++ b/numpy/core/src/multiarray/convert.c
@@ -424,7 +424,9 @@ PyArray_FillWithScalar(PyArrayObject *arr, PyObject *obj)
}
/*
- * Fills an array with zeros.
+ * Internal function to fill an array with zeros.
+ * Used in einsum and dot, which ensures the dtype is, in some sense, numerical
+ * and not a str or struct
*
* dst: The destination array.
* wheremask: If non-NULL, a boolean mask specifying where to set the values.
@@ -435,21 +437,26 @@ NPY_NO_EXPORT int
PyArray_AssignZero(PyArrayObject *dst,
PyArrayObject *wheremask)
{
- npy_bool value;
- PyArray_Descr *bool_dtype;
- int retcode;
-
- /* Create a raw bool scalar with the value False */
- bool_dtype = PyArray_DescrFromType(NPY_BOOL);
- if (bool_dtype == NULL) {
- return -1;
+ int retcode = 0;
+ if (PyArray_ISOBJECT(dst)) {
+ PyObject * pZero = PyLong_FromLong(0);
+ retcode = PyArray_AssignRawScalar(dst, PyArray_DESCR(dst),
+ (char *)&pZero, wheremask, NPY_SAFE_CASTING);
+ Py_DECREF(pZero);
}
- value = 0;
+ else {
+ /* Create a raw bool scalar with the value False */
+ PyArray_Descr *bool_dtype = PyArray_DescrFromType(NPY_BOOL);
+ if (bool_dtype == NULL) {
+ return -1;
+ }
+ npy_bool value = 0;
- retcode = PyArray_AssignRawScalar(dst, bool_dtype, (char *)&value,
- wheremask, NPY_SAFE_CASTING);
+ retcode = PyArray_AssignRawScalar(dst, bool_dtype, (char *)&value,
+ wheremask, NPY_SAFE_CASTING);
- Py_DECREF(bool_dtype);
+ Py_DECREF(bool_dtype);
+ }
return retcode;
}
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 9f9d41579..0b3b3fd8c 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -17,6 +17,7 @@
#include <numpy/npy_common.h>
#include <numpy/arrayobject.h>
#include <npy_pycompat.h>
+#include <array_assign.h> //PyArray_AssignRawScalar
#include <ctype.h>
@@ -531,10 +532,17 @@ unbuffered_loop_nop1_ndim2(NpyIter *iter)
* Since the iterator wasn't tracking coordinates, the
* loop provided by the iterator is in Fortran-order.
*/
- NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
+ int needs_api = NpyIter_IterationNeedsAPI(iter);
+ if (!needs_api) {
+ NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
+ }
for (coord = shape[1]; coord > 0; --coord) {
sop(1, ptrs[0], strides[0], shape[0]);
+ if (needs_api && PyErr_Occurred()){
+ return -1;
+ }
+
ptr = ptrs[1][0] + strides[1][0];
ptrs[0][0] = ptrs[1][0] = ptr;
ptr = ptrs[1][1] + strides[1][1];
@@ -585,11 +593,18 @@ unbuffered_loop_nop1_ndim3(NpyIter *iter)
* Since the iterator wasn't tracking coordinates, the
* loop provided by the iterator is in Fortran-order.
*/
- NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
+ int needs_api = NpyIter_IterationNeedsAPI(iter);
+ if (!needs_api) {
+ NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
+ }
for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) {
for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) {
sop(1, ptrs[0], strides[0], shape[0]);
+ if (needs_api && PyErr_Occurred()){
+ return -1;
+ }
+
ptr = ptrs[1][0] + strides[1][0];
ptrs[0][0] = ptrs[1][0] = ptr;
ptr = ptrs[1][1] + strides[1][1];
@@ -642,10 +657,17 @@ unbuffered_loop_nop2_ndim2(NpyIter *iter)
* Since the iterator wasn't tracking coordinates, the
* loop provided by the iterator is in Fortran-order.
*/
- NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
+ int needs_api = NpyIter_IterationNeedsAPI(iter);
+ if (!needs_api) {
+ NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
+ }
for (coord = shape[1]; coord > 0; --coord) {
sop(2, ptrs[0], strides[0], shape[0]);
+ if(needs_api && PyErr_Occurred()){
+ return -1;
+ }
+
ptr = ptrs[1][0] + strides[1][0];
ptrs[0][0] = ptrs[1][0] = ptr;
ptr = ptrs[1][1] + strides[1][1];
@@ -698,11 +720,18 @@ unbuffered_loop_nop2_ndim3(NpyIter *iter)
* Since the iterator wasn't tracking coordinates, the
* loop provided by the iterator is in Fortran-order.
*/
- NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
+ int needs_api = NpyIter_IterationNeedsAPI(iter);
+ if (!needs_api) {
+ NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
+ }
for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) {
for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) {
sop(2, ptrs[0], strides[0], shape[0]);
+ if(needs_api && PyErr_Occurred()){
+ return -1;
+ }
+
ptr = ptrs[1][0] + strides[1][0];
ptrs[0][0] = ptrs[1][0] = ptr;
ptr = ptrs[1][1] + strides[1][1];
@@ -1024,7 +1053,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
goto fail;
}
- /* Initialize the output to all zeros */
+ /* Initialize the output to all zeros or None*/
ret = NpyIter_GetOperandArray(iter)[nop];
if (PyArray_AssignZero(ret, NULL) < 0) {
goto fail;
diff --git a/numpy/core/src/multiarray/einsum_sumprod.c.src b/numpy/core/src/multiarray/einsum_sumprod.c.src
index e7b2f2c2c..5f2ccf2dc 100644
--- a/numpy/core/src/multiarray/einsum_sumprod.c.src
+++ b/numpy/core/src/multiarray/einsum_sumprod.c.src
@@ -1026,6 +1026,57 @@ bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
/**end repeat**/
+/**begin repeat
+ * #fn_name =
+ * object_sum_of_products_any,
+ * object_sum_of_products_one,
+ * object_sum_of_products_two,
+ * object_sum_of_products_three,
+ * object_sum_of_products_contig_any,
+ * object_sum_of_products_contig_one,
+ * object_sum_of_products_contig_two,
+ * object_sum_of_products_contig_three,
+ * object_sum_of_products_outstride0_any,
+ * object_sum_of_products_outstride0_one,
+ * object_sum_of_products_outstride0_two,
+ * object_sum_of_products_outstride0_three#
+ */
+static void
+@fn_name@(int nop, char **dataptr,
+ npy_intp const *strides, npy_intp count)
+{
+ while(count--){
+ PyObject *prod = *(PyObject **)dataptr[0];
+ if (!prod) {
+ prod = Py_None; // convention is to treat nulls as None
+ }
+ Py_INCREF(prod);
+ for (int i = 1; i < nop; ++i){
+ PyObject *curr = *(PyObject **)dataptr[i];
+ if (!curr) {
+ curr = Py_None; // convention is to treat nulls as None
+ }
+ Py_SETREF(prod, PyNumber_Multiply(prod, curr));
+ if (!prod) {
+ return;
+ }
+ }
+
+ PyObject *sum = PyNumber_Add(*(PyObject **)dataptr[nop], prod);
+ Py_DECREF(prod);
+ if (!sum) {
+ return;
+ }
+
+ Py_XDECREF(*(PyObject **)dataptr[nop]);
+ *(PyObject **)dataptr[nop] = sum;
+ for (int i = 0; i <= nop; ++i) {
+ dataptr[i] += strides[i];
+ }
+ }
+}
+/**end repeat**/
+
/* These tables need to match up with the type enum */
static sum_of_products_fn
_contig_outstride0_unary_specialization_table[NPY_NTYPES] = {
@@ -1116,7 +1167,7 @@ static sum_of_products_fn _outstride0_specialized_table[NPY_NTYPES][4] = {
* 1, 1,
* 1, 1, 1,
* 1, 1, 1,
- * 0, 0, 0, 0,
+ * 1, 0, 0, 0,
* 0, 0, 1#
*/
#if @use@
@@ -1152,7 +1203,7 @@ static sum_of_products_fn _allcontig_specialized_table[NPY_NTYPES][4] = {
* 1, 1,
* 1, 1, 1,
* 1, 1, 1,
- * 0, 0, 0, 0,
+ * 1, 0, 0, 0,
* 0, 0, 1#
*/
#if @use@
@@ -1188,7 +1239,7 @@ static sum_of_products_fn _unspecialized_table[NPY_NTYPES][4] = {
* 1, 1,
* 1, 1, 1,
* 1, 1, 1,
- * 0, 0, 0, 0,
+ * 1, 0, 0, 0,
* 0, 0, 1#
*/
#if @use@
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 029e41990..d7d19493b 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -66,6 +66,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "compiled_base.h"
#include "mem_overlap.h"
#include "typeinfo.h"
+#include "convert.h" /* for PyArray_AssignZero */
#include "get_attr_string.h"
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */
@@ -1084,7 +1085,9 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
}
/* Ensure that multiarray.dot(<Nx0>,<0xM>) -> zeros((N,M)) */
if (PyArray_SIZE(ap1) == 0 && PyArray_SIZE(ap2) == 0) {
- memset(PyArray_DATA(out_buf), 0, PyArray_NBYTES(out_buf));
+ if (PyArray_AssignZero(out_buf, NULL) < 0) {
+ goto fail;
+ }
}
dot = PyArray_DESCR(out_buf)->f->dotfunc;
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 043785782..3a06d119f 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -100,6 +100,69 @@ class TestEinsum:
assert_raises(ValueError, np.einsum, "i->i", np.arange(6).reshape(-1, 1),
optimize=do_opt, order='d')
+ def test_einsum_object_errors(self):
+ # Exceptions created by object arithmetic should
+ # successfully propogate
+
+ class CustomException(Exception):
+ pass
+
+ class DestructoBox:
+
+ def __init__(self, value, destruct):
+ self._val = value
+ self._destruct = destruct
+
+ def __add__(self, other):
+ tmp = self._val + other._val
+ if tmp >= self._destruct:
+ raise CustomException
+ else:
+ self._val = tmp
+ return self
+
+ def __radd__(self, other):
+ if other == 0:
+ return self
+ else:
+ return self.__add__(other)
+
+ def __mul__(self, other):
+ tmp = self._val * other._val
+ if tmp >= self._destruct:
+ raise CustomException
+ else:
+ self._val = tmp
+ return self
+
+ def __rmul__(self, other):
+ if other == 0:
+ return self
+ else:
+ return self.__mul__(other)
+
+ a = np.array([DestructoBox(i, 5) for i in range(1, 10)],
+ dtype='object').reshape(3, 3)
+
+ # raised from unbuffered_loop_nop1_ndim2
+ assert_raises(CustomException, np.einsum, "ij->i", a)
+
+ # raised from unbuffered_loop_nop1_ndim3
+ b = np.array([DestructoBox(i, 100) for i in range(0, 27)],
+ dtype='object').reshape(3, 3, 3)
+ assert_raises(CustomException, np.einsum, "i...k->...", b)
+
+ # raised from unbuffered_loop_nop2_ndim2
+ b = np.array([DestructoBox(i, 55) for i in range(1, 4)],
+ dtype='object')
+ assert_raises(CustomException, np.einsum, "ij, j", a, b)
+
+ # raised from unbuffered_loop_nop2_ndim3
+ assert_raises(CustomException, np.einsum, "ij, jh", a, a)
+
+ # raised from PyArray_EinsteinSum
+ assert_raises(CustomException, np.einsum, "ij->", a)
+
def test_einsum_views(self):
# pass-through
for do_opt in [True, False]:
@@ -247,47 +310,50 @@ class TestEinsum:
# sum(a, axis=-1)
for n in range(1, 17):
a = np.arange(n, dtype=dtype)
- assert_equal(np.einsum("i->", a, optimize=do_opt),
- np.sum(a, axis=-1).astype(dtype))
- assert_equal(np.einsum(a, [0], [], optimize=do_opt),
- np.sum(a, axis=-1).astype(dtype))
+ b = np.sum(a, axis=-1)
+ if hasattr(b, 'astype'):
+ b = b.astype(dtype)
+ assert_equal(np.einsum("i->", a, optimize=do_opt), b)
+ assert_equal(np.einsum(a, [0], [], optimize=do_opt), b)
for n in range(1, 17):
a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
- assert_equal(np.einsum("...i->...", a, optimize=do_opt),
- np.sum(a, axis=-1).astype(dtype))
- assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
- np.sum(a, axis=-1).astype(dtype))
+ b = np.sum(a, axis=-1)
+ if hasattr(b, 'astype'):
+ b = b.astype(dtype)
+ assert_equal(np.einsum("...i->...", a, optimize=do_opt), b)
+ assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt), b)
# sum(a, axis=0)
for n in range(1, 17):
a = np.arange(2*n, dtype=dtype).reshape(2, n)
- assert_equal(np.einsum("i...->...", a, optimize=do_opt),
- np.sum(a, axis=0).astype(dtype))
- assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
- np.sum(a, axis=0).astype(dtype))
+ b = np.sum(a, axis=0)
+ if hasattr(b, 'astype'):
+ b = b.astype(dtype)
+ assert_equal(np.einsum("i...->...", a, optimize=do_opt), b)
+ assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), b)
for n in range(1, 17):
a = np.arange(2*3*n, dtype=dtype).reshape(2, 3, n)
- assert_equal(np.einsum("i...->...", a, optimize=do_opt),
- np.sum(a, axis=0).astype(dtype))
- assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt),
- np.sum(a, axis=0).astype(dtype))
+ b = np.sum(a, axis=0)
+ if hasattr(b, 'astype'):
+ b = b.astype(dtype)
+ assert_equal(np.einsum("i...->...", a, optimize=do_opt), b)
+ assert_equal(np.einsum(a, [0, Ellipsis], [Ellipsis], optimize=do_opt), b)
# trace(a)
for n in range(1, 17):
a = np.arange(n*n, dtype=dtype).reshape(n, n)
- assert_equal(np.einsum("ii", a, optimize=do_opt),
- np.trace(a).astype(dtype))
- assert_equal(np.einsum(a, [0, 0], optimize=do_opt),
- np.trace(a).astype(dtype))
+ b = np.trace(a)
+ if hasattr(b, 'astype'):
+ b = b.astype(dtype)
+ assert_equal(np.einsum("ii", a, optimize=do_opt), b)
+ assert_equal(np.einsum(a, [0, 0], optimize=do_opt), b)
# gh-15961: should accept numpy int64 type in subscript list
np_array = np.asarray([0, 0])
- assert_equal(np.einsum(a, np_array, optimize=do_opt),
- np.trace(a).astype(dtype))
- assert_equal(np.einsum(a, list(np_array), optimize=do_opt),
- np.trace(a).astype(dtype))
+ assert_equal(np.einsum(a, np_array, optimize=do_opt), b)
+ assert_equal(np.einsum(a, list(np_array), optimize=do_opt), b)
# multiply(a, b)
assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case
@@ -489,11 +555,15 @@ class TestEinsum:
b = np.einsum("i->", a, dtype=dtype, casting='unsafe')
assert_equal(b, np.sum(a))
- assert_equal(b.dtype, np.dtype(dtype))
+ if hasattr(b, "dtype"):
+ # Can be a python object when dtype is object
+ assert_equal(b.dtype, np.dtype(dtype))
b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe')
assert_equal(b, np.sum(a))
- assert_equal(b.dtype, np.dtype(dtype))
+ if hasattr(b, "dtype"):
+ # Can be a python object when dtype is object
+ assert_equal(b.dtype, np.dtype(dtype))
# A case which was failing (ticket #1885)
p = np.arange(2) + 1
@@ -587,6 +657,10 @@ class TestEinsum:
def test_einsum_sums_clongdouble(self):
self.check_einsum_sums(np.clongdouble)
+ def test_einsum_sums_object(self):
+ self.check_einsum_sums('object')
+ self.check_einsum_sums('object', True)
+
def test_einsum_misc(self):
# This call used to crash because of a bug in
# PyArray_AssignZero
@@ -625,6 +699,21 @@ class TestEinsum:
# see issue gh-15776 and issue gh-15256
assert_equal(np.einsum('i,j', [1], [2], out=None), [[2]])
+ def test_object_loop(self):
+
+ class Mult:
+ def __mul__(self, other):
+ return 42
+
+ objMult = np.array([Mult()])
+ objNULL = np.ndarray(buffer = b'\0' * np.intp(0).itemsize, shape=1, dtype=object)
+
+ with pytest.raises(TypeError):
+ np.einsum("i,j", [1], objNULL)
+ with pytest.raises(TypeError):
+ np.einsum("i,j", objNULL, [1])
+ assert np.einsum("i,j", objMult, objMult) == 42
+
def test_subscript_range(self):
# Issue #7741, make sure that all letters of Latin alphabet (both uppercase & lowercase) can be used
# when creating a subscript from arrays
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index fdd536bb9..141636034 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1528,9 +1528,12 @@ class TestRegression:
for y in dtypes:
c = a.astype(y)
try:
- np.dot(b, c)
+ d = np.dot(b, c)
except TypeError:
failures.append((x, y))
+ else:
+ if d != 0:
+ failures.append((x, y))
if failures:
raise AssertionError("Failures: %r" % failures)