summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-03-23 16:45:04 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-02-27 20:03:26 -0500
commit5626be617d9e5b3b8758b875adf7c7c2356bf9b3 (patch)
treec7993c02297d5252320e15afc393fabced5e4a84
parent042886c4369956db23d4d05a5cbe7baa4fb9f86b (diff)
downloadnumpy-5626be617d9e5b3b8758b875adf7c7c2356bf9b3.tar.gz
ENH: Implement axes keyword argument for gufuncs.
The axes argument allows one to specify the axes on which the gufunc will operate (by default, the trailing ones). It has to be a list with length equal to the number of operands, and each element a tuple of length equal to the number of core dimensions, with each element an axis index. If there is only one core dimension, the tuple can be replaced by a single index, and if none of the outputs have core dimensions, the corresponding empty tuples can be omitted.
-rw-r--r--numpy/core/src/umath/ufunc_object.c223
-rw-r--r--numpy/core/tests/test_ufunc.py93
2 files changed, 300 insertions, 16 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index c67f60752..1ffecb1a6 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -555,7 +555,8 @@ get_ufunc_arguments(PyUFuncObject *ufunc,
PyObject **out_extobj,
PyObject **out_typetup,
int *out_subok,
- PyArrayObject **out_wheremask)
+ PyArrayObject **out_wheremask,
+ PyObject **out_axes)
{
int i, nargs;
int nin = ufunc->nin;
@@ -570,6 +571,9 @@ get_ufunc_arguments(PyUFuncObject *ufunc,
*out_extobj = NULL;
*out_typetup = NULL;
+ if (out_axes != NULL) {
+ *out_axes = NULL;
+ }
if (out_wheremask != NULL) {
*out_wheremask = NULL;
}
@@ -806,6 +810,13 @@ get_ufunc_arguments(PyUFuncObject *ufunc,
}
switch (str[0]) {
+ case 'a':
+ /* possible axis argument for generalized ufunc */
+ if (out_axes != NULL && strcmp(str, "axes") == 0) {
+ *out_axes = value;
+ bad_arg = 0;
+ }
+ break;
case 'c':
/* Provides a policy for allowed casting */
if (strcmp(str, "casting") == 0) {
@@ -995,6 +1006,10 @@ fail:
*out_extobj = NULL;
Py_XDECREF(*out_typetup);
*out_typetup = NULL;
+ if (out_axes != NULL) {
+ Py_XDECREF(*out_axes);
+ *out_axes = NULL;
+ }
if (out_wheremask != NULL) {
Py_XDECREF(*out_wheremask);
*out_wheremask = NULL;
@@ -1760,6 +1775,155 @@ make_arr_prep_args(npy_intp nin, PyObject *args, PyObject *kwds)
}
/*
+ * Check whether any of the outputs of a gufunc has core dimensions.
+ */
+static int
+_has_output_coredims(PyUFuncObject *ufunc) {
+ int i;
+ for (i = ufunc->nin; i < ufunc->nin + ufunc->nout; ++i) {
+ if (ufunc->core_num_dims[i] > 0) {
+ return 1;
+ }
+ }
+ return 0;
+}
+
+/*
+ * Interpret a possible axes keyword argument, using it to fill the remap_axis
+ * array which maps default to actual axes for each operand, indexed as
+ * as remap_axis[iop][iaxis]. The default axis order has first all broadcast
+ * axes and then the core axes the gufunc operates on.
+ *
+ * Returns 0 on success, and -1 on failure
+ */
+static int
+_parse_axes_arg(PyUFuncObject *ufunc, PyObject *axes, PyArrayObject **op,
+ int broadcast_ndim, int **remap_axis) {
+ int nin = ufunc->nin;
+ int nout = ufunc->nout;
+ int nop = nin + nout;
+ int iop, list_size;
+
+ if (!PyList_Check(axes)) {
+ PyErr_SetString(PyExc_TypeError, "axes should be a list.");
+ return -1;
+ }
+ list_size = PyList_Size(axes);
+ if (list_size != nop) {
+ if (list_size != nin || _has_output_coredims(ufunc)) {
+ PyErr_Format(PyExc_ValueError,
+ "axes should be a list with an entry for all "
+ "%d inputs and outputs; entries for outputs can only "
+ "be omitted if none of them has core axes.",
+ nop);
+ return -1;
+ }
+ for (iop = nin; iop < nop; iop++) {
+ remap_axis[iop] = NULL;
+ }
+ }
+ for (iop = 0; iop < list_size; ++iop) {
+ int op_ndim, op_ncore, op_nbroadcast;
+ int have_seen_axis[NPY_MAXDIMS] = {0};
+ PyObject *op_axes_tuple, *axis_item;
+ int axis, op_axis;
+
+ op_ncore = ufunc->core_num_dims[iop];
+ if (op[iop] != NULL) {
+ op_ndim = PyArray_NDIM(op[iop]);
+ op_nbroadcast = op_ndim - op_ncore;
+ }
+ else {
+ op_nbroadcast = broadcast_ndim;
+ op_ndim = broadcast_ndim + op_ncore;
+ }
+ /*
+ * Get axes tuple for operand. If not a tuple already, make it one if
+ * there is only one axis (its content is checked later).
+ */
+ op_axes_tuple = PyList_GET_ITEM(axes, iop);
+ if (PyTuple_Check(op_axes_tuple)) {
+ if (PyTuple_Size(op_axes_tuple) != op_ncore) {
+ if (op_ncore == 1) {
+ PyErr_Format(PyExc_ValueError,
+ "axes item %d should be a tuple with a "
+ "single element, or an integer", iop);
+ }
+ else {
+ PyErr_Format(PyExc_ValueError,
+ "axes item %d should be a tuple with %d "
+ "elements", iop, op_ncore);
+ }
+ return -1;
+ }
+ Py_INCREF(op_axes_tuple);
+ }
+ else if (op_ncore == 1) {
+ op_axes_tuple = PyTuple_Pack(1, op_axes_tuple);
+ if (op_axes_tuple == NULL) {
+ return -1;
+ }
+ }
+ else {
+ PyErr_Format(PyExc_TypeError, "axes item %d should be a tuple",
+ iop);
+ return -1;
+ }
+ /*
+ * Now create the remap, starting with the core dimensions, and then
+ * adding the remaining broadcast axes that are to be iterated over.
+ */
+ for (axis = op_nbroadcast; axis < op_ndim; axis++) {
+ axis_item = PyTuple_GET_ITEM(op_axes_tuple, axis - op_nbroadcast);
+ op_axis = PyArray_PyIntAsInt(axis_item);
+ if (error_converting(op_axis) ||
+ (check_and_adjust_axis(&op_axis, op_ndim) < 0)) {
+ Py_DECREF(op_axes_tuple);
+ return -1;
+ }
+ if (have_seen_axis[op_axis]) {
+ PyErr_Format(PyExc_ValueError,
+ "axes item %d has value %d repeated",
+ iop, op_axis);
+ Py_DECREF(op_axes_tuple);
+ return -1;
+ }
+ have_seen_axis[op_axis] = 1;
+ remap_axis[iop][axis] = op_axis;
+ }
+ Py_DECREF(op_axes_tuple);
+ /*
+ * Fill the op_nbroadcast=op_ndim-op_ncore axes not yet set,
+ * using have_seen_axis to skip over entries set above.
+ */
+ for (axis = 0, op_axis = 0; axis < op_nbroadcast; axis++) {
+ while (have_seen_axis[op_axis]) {
+ op_axis++;
+ }
+ remap_axis[iop][axis] = op_axis++;
+ }
+ /*
+ * Check whether we are actually remapping anything. Here,
+ * op_axis can only equal axis if all broadcast axes were the same
+ * (i.e., the while loop above was never entered).
+ */
+ if (axis == op_axis) {
+ while (axis < op_ndim && remap_axis[iop][axis] == axis) {
+ axis++;
+ }
+ }
+ if (axis == op_ndim) {
+ remap_axis[iop] = NULL;
+ }
+ } /* end of for(iop) loop over operands */
+ return 0;
+}
+
+#define REMAP_AXIS(iop, axis) ((remap_axis != NULL && \
+ remap_axis[iop] != NULL)? \
+ remap_axis[iop][axis] : axis)
+
+/*
* Validate the core dimensions of all the operands, and collect all of
* the labelled core dimensions into 'core_dim_sizes'.
*
@@ -1781,7 +1945,7 @@ make_arr_prep_args(npy_intp nin, PyObject *args, PyObject *kwds)
*/
static int
_get_coredim_sizes(PyUFuncObject *ufunc, PyArrayObject **op,
- npy_intp* core_dim_sizes) {
+ npy_intp* core_dim_sizes, int **remap_axis) {
int i;
int nin = ufunc->nin;
int nout = ufunc->nout;
@@ -1815,8 +1979,8 @@ _get_coredim_sizes(PyUFuncObject *ufunc, PyArrayObject **op,
*/
for (idim = 0; idim < num_dims; ++idim) {
int core_dim_index = ufunc->core_dim_ixs[dim_offset+idim];
- npy_intp op_dim_size =
- PyArray_DIM(op[i], core_start_dim+idim);
+ npy_intp op_dim_size = PyArray_DIM(
+ op[i], REMAP_AXIS(i, core_start_dim+idim));
if (core_dim_sizes[core_dim_index] == -1) {
core_dim_sizes[core_dim_index] = op_dim_size;
@@ -1950,7 +2114,9 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/* The sizes of the core dimensions (# entries is ufunc->core_num_dim_ix) */
npy_intp *core_dim_sizes = inner_dimensions + 1;
int core_dim_ixs_size;
-
+ /* swapping around of axes */
+ int *remap_axis_memory = NULL;
+ int **remap_axis = NULL;
/* The __array_prepare__ function to call for each output */
PyObject *arr_prep[NPY_MAXARGS];
/*
@@ -1962,8 +2128,8 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
NPY_ORDER order = NPY_KEEPORDER;
/* Use the default assignment casting rule */
NPY_CASTING casting = NPY_DEFAULT_ASSIGN_CASTING;
- /* When provided, extobj and typetup contain borrowed references */
- PyObject *extobj = NULL, *type_tup = NULL;
+ /* When provided, extobj, typetup, and axes contain borrowed references */
+ PyObject *extobj = NULL, *type_tup = NULL, *axes = NULL;
if (ufunc == NULL) {
PyErr_SetString(PyExc_ValueError, "function not supported");
@@ -1990,7 +2156,7 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/* Get all the arguments */
retval = get_ufunc_arguments(ufunc, args, kwds,
op, &order, &casting, &extobj,
- &type_tup, &subok, NULL);
+ &type_tup, &subok, NULL, &axes);
if (retval < 0) {
goto fail;
}
@@ -2026,8 +2192,30 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
goto fail;
}
+ /* Possibly remap axes. */
+ if (axes) {
+ /*
+ * possibly remap axes, using newly allocated memory.
+ */
+ remap_axis = PyArray_malloc(sizeof(remap_axis[0]) * nop);
+ remap_axis_memory = PyArray_malloc(sizeof(remap_axis_memory[0]) *
+ nop * NPY_MAXDIMS);
+ if (remap_axis == NULL || remap_axis_memory == NULL) {
+ PyErr_NoMemory();
+ goto fail;
+ }
+ for (i=0; i < nop; i++) {
+ remap_axis[i] = remap_axis_memory + i * NPY_MAXDIMS;
+ }
+ retval = _parse_axes_arg(ufunc, axes, op, broadcast_ndim,
+ remap_axis);
+ if(retval < 0) {
+ goto fail;
+ }
+ } /* end of if(axis) */
+
/* Collect the lengths of the labelled core dimensions */
- retval = _get_coredim_sizes(ufunc, op, core_dim_sizes);
+ retval = _get_coredim_sizes(ufunc, op, core_dim_sizes, remap_axis);
if(retval < 0) {
goto fail;
}
@@ -2054,7 +2242,8 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/* Broadcast all the unspecified dimensions normally */
for (idim = 0; idim < broadcast_ndim; ++idim) {
if (idim >= broadcast_ndim - n) {
- op_axes_arrays[i][idim] = idim - (broadcast_ndim - n);
+ op_axes_arrays[i][idim] =
+ REMAP_AXIS(i, idim - (broadcast_ndim - n));
}
else {
op_axes_arrays[i][idim] = -1;
@@ -2074,7 +2263,7 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
for (idim = 0; idim < num_dims; ++idim) {
iter_shape[j] = core_dim_sizes[
ufunc->core_dim_ixs[dim_offset + idim]];
- op_axes_arrays[i][j] = n + idim;
+ op_axes_arrays[i][j] = REMAP_AXIS(i, n + idim);
++j;
}
}
@@ -2220,11 +2409,12 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
for (j = 0; j < num_dims; ++j) {
if (core_start_dim + j >= 0) {
/*
- * Force the stride to zero when the shape is 1, sot
+ * Force the stride to zero when the shape is 1, so
* that the broadcasting works right.
*/
- if (shape[core_start_dim + j] != 1) {
- inner_strides[idim++] = strides[core_start_dim + j];
+ int remapped_axis = REMAP_AXIS(i, core_start_dim + j);
+ if (shape[remapped_axis] != 1) {
+ inner_strides[idim++] = strides[remapped_axis];
} else {
inner_strides[idim++] = 0;
}
@@ -2375,7 +2565,8 @@ fail:
}
Py_XDECREF(type_tup);
Py_XDECREF(arr_prep_args);
-
+ PyArray_free(remap_axis_memory);
+ PyArray_free(remap_axis);
return retval;
}
@@ -2450,7 +2641,7 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
/* Get all the arguments */
retval = get_ufunc_arguments(ufunc, args, kwds,
op, &order, &casting, &extobj,
- &type_tup, &subok, &wheremask);
+ &type_tup, &subok, &wheremask, NULL);
if (retval < 0) {
goto fail;
}
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 57e0ec272..239e8bc64 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -597,6 +597,99 @@ class TestUfunc(object):
umt.inner1d(a, b, out=c[..., 0])
assert_array_equal(c[..., 0], np.sum(a*b, axis=-1), err_msg=msg)
+ def test_axes_argument(self):
+ # inner1d signature: '(i),(i)->()'
+ in1d = umt.inner1d
+ a = np.arange(27.).reshape((3, 3, 3))
+ b = np.arange(10., 19.).reshape((3, 1, 3))
+ # basic tests on inputs (outputs tested below with matrix_multiply).
+ c = in1d(a, b)
+ assert_array_equal(c, (a * b).sum(-1))
+ # default
+ c = in1d(a, b, axes=[(-1,), (-1,), ()])
+ assert_array_equal(c, (a * b).sum(-1))
+ # integers ok for single axis.
+ c = in1d(a, b, axes=[-1, -1, ()])
+ assert_array_equal(c, (a * b).sum(-1))
+ # mix fine
+ c = in1d(a, b, axes=[(-1,), -1, ()])
+ assert_array_equal(c, (a * b).sum(-1))
+ # can omit last axis.
+ c = in1d(a, b, axes=[-1, -1])
+ assert_array_equal(c, (a * b).sum(-1))
+ # can pass in other types of integer (with __index__ protocol)
+ c = in1d(a, b, axes=[np.int8(-1), np.array(-1, dtype=np.int32)])
+ assert_array_equal(c, (a * b).sum(-1))
+ # swap some axes
+ c = in1d(a, b, axes=[0, 0])
+ assert_array_equal(c, (a * b).sum(0))
+ c = in1d(a, b, axes=[0, 2])
+ assert_array_equal(c, (a.transpose(1, 2, 0) * b).sum(-1))
+ # Check errors for inproperly constructed axes arguments.
+ # should have list.
+ assert_raises(TypeError, in1d, a, b, axes=-1)
+ # needs enough elements
+ assert_raises(ValueError, in1d, a, b, axes=[-1])
+ # should pass in indices.
+ assert_raises(TypeError, in1d, a, b, axes=[-1.0, -1.0])
+ assert_raises(TypeError, in1d, a, b, axes=[(-1.0,), -1])
+ assert_raises(TypeError, in1d, a, b, axes=[None, 1])
+ # cannot pass an index unless there is only one dimension
+ # (output is wrong in this case)
+ assert_raises(TypeError, in1d, a, b, axes=[-1, -1, -1])
+ # or pass in generally the wrong number of axes
+ assert_raises(ValueError, in1d, a, b, axes=[-1, -1, (-1,)])
+ assert_raises(ValueError, in1d, a, b, axes=[-1, (-2, -1), ()])
+ # axes need to have same length.
+ assert_raises(ValueError, in1d, a, b, axes=[0, 1])
+
+ # matrix_multiply signature: '(m,n),(n,p)->(m,p)'
+ mm = umt.matrix_multiply
+ a = np.arange(12).reshape((2, 3, 2))
+ b = np.arange(8).reshape((2, 2, 2, 1)) + 1
+ # Sanity check.
+ c = mm(a, b)
+ assert_array_equal(c, np.matmul(a, b))
+ # Default axes.
+ c = mm(a, b, axes=[(-2, -1), (-2, -1), (-2, -1)])
+ assert_array_equal(c, np.matmul(a, b))
+ # Default with explicit axes.
+ c = mm(a, b, axes=[(1, 2), (2, 3), (2, 3)])
+ assert_array_equal(c, np.matmul(a, b))
+ # swap some axes.
+ c = mm(a, b, axes=[(0, -1), (1, 2), (-2, -1)])
+ assert_array_equal(c, np.matmul(a.transpose(1, 0, 2),
+ b.transpose(0, 3, 1, 2)))
+ # Default with output array.
+ c = np.empty((2, 2, 3, 1))
+ d = mm(a, b, out=c, axes=[(1, 2), (2, 3), (2, 3)])
+ assert_(c is d)
+ assert_array_equal(c, np.matmul(a, b))
+ # Transposed output array
+ c = np.empty((1, 2, 2, 3))
+ d = mm(a, b, out=c, axes=[(-2, -1), (-2, -1), (3, 0)])
+ assert_(c is d)
+ assert_array_equal(c, np.matmul(a, b).transpose(3, 0, 1, 2))
+ # Check errors for inproperly constructed axes arguments.
+ # wrong argument
+ assert_raises(TypeError, mm, a, b, axis=1)
+ # axes should be list
+ assert_raises(TypeError, mm, a, b, axes=1)
+ assert_raises(TypeError, mm, a, b, axes=((-2, -1), (-2, -1), (-2, -1)))
+ # list needs to have right length
+ assert_raises(ValueError, mm, a, b, axes=[])
+ assert_raises(ValueError, mm, a, b, axes=[(-2, -1)])
+ # list should contain tuples for multiple axes
+ assert_raises(TypeError, mm, a, b, axes=[-1, -1, -1])
+ assert_raises(TypeError, mm, a, b, axes=[(-2, -1), (-2, -1), -1])
+ assert_raises(TypeError,
+ mm, a, b, axes=[[-2, -1], [-2, -1], [-2, -1]])
+ assert_raises(TypeError,
+ mm, a, b, axes=[(-2, -1), (-2, -1), [-2, -1]])
+ assert_raises(TypeError, mm, a, b, axes=[(-2, -1), (-2, -1), None])
+ # tuples should not have duplicated values
+ assert_raises(ValueError, mm, a, b, axes=[(-2, -1), (-2, -1), (-2, -2)])
+
def test_innerwt(self):
a = np.arange(6).reshape((2, 3))
b = np.arange(10, 16).reshape((2, 3))