summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/arraymethods.c11
-rw-r--r--numpy/lib/function_base.py17
2 files changed, 17 insertions, 11 deletions
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c
index af4a6049e..a4fcc0f47 100644
--- a/numpy/core/src/arraymethods.c
+++ b/numpy/core/src/arraymethods.c
@@ -670,11 +670,12 @@ array_repeat(PyArrayObject *self, PyObject *args, PyObject *kwds) {
}
static char doc_choose[] = "a.choose(b0, b1, ..., bn)\n"\
- "\n"\
- "Return an array with elements chosen from 'a' at the positions\n"\
- "of the given arrays b_i. The array 'a' should be an integer array\n"\
- "with entries from 0 to n+1, and the b_i arrays should have the same\n"\
- "shape as 'a'.";
+ "\n" \
+ "Return an array that merges the b_i arrays together using 'a' as the index\n"
+ "The b_i arrays and 'a' must all be broadcastable to the same shape.\n"
+ "The output at a particular position is the input array b_i at that position\n"
+ "depending on the value of 'a' at that position. Therefore, 'a' must be\n"
+ "an integer array with entries from 0 to n+1.";
static PyObject *
array_choose(PyArrayObject *self, PyObject *args)
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 9969834a8..3143c8846 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -617,11 +617,15 @@ class vectorize(object):
"""
def __init__(self, pyfunc, otypes='', doc=None):
- nin, ndefault = _get_nargs(pyfunc)
self.thefunc = pyfunc
self.ufunc = None
- self.nin = nin
- self.nin_wo_defaults = nin - ndefault
+ nin, ndefault = _get_nargs(pyfunc)
+ if nin == 0 and ndefault == 0:
+ self.nin = None
+ self.nin_wo_defaults = None
+ else:
+ self.nin = nin
+ self.nin_wo_defaults = nin - ndefault
self.nout = None
if doc is None:
self.__doc__ = pyfunc.__doc__
@@ -640,9 +644,10 @@ class vectorize(object):
# get number of outputs and output types by calling
# the function on the first entries of args
nargs = len(args)
- if (nargs > self.nin) or (nargs < self.nin_wo_defaults):
- raise ValueError, "mismatch between python function inputs"\
- " and received arguments"
+ if self.nin:
+ if (nargs > self.nin) or (nargs < self.nin_wo_defaults):
+ raise ValueError, "mismatch between python function inputs"\
+ " and received arguments"
if self.nout is None or self.otypes == '':
newargs = []
for arg in args: