diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index d7eaa75e5..d7e687eaa 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -764,11 +764,17 @@ class vectorize(object): Description: Define a vectorized function which takes nested sequence - objects or numpy arrays as inputs and returns a + of objects or numpy arrays as inputs and returns a numpy array as output, evaluating the function over successive tuples of the input arrays like the python map function except it uses the broadcasting rules of numpy. + Data-type of output of vectorized is determined by calling the function + with the first element of the input. This can be avoided by specifying + the otypes argument as either a string of typecode characters or a list + of data-types specifiers. There should be one data-type specifier for + each output. + Input: somefunction -- a Python function or method @@ -804,11 +810,13 @@ class vectorize(object): self.__doc__ = doc if isinstance(otypes, types.StringType): self.otypes = otypes + for char in self.otypes: + if char not in typecodes['All']: + raise ValueError, "invalid otype specified" + elif iterable(otypes): + self.otypes = ''.join([_nx.dtype(x).char for x in otypes]) else: - raise ValueError, "output types must be a string" - for char in self.otypes: - if char not in typecodes['All']: - raise ValueError, "invalid typecode specified" + raise ValueError, "output types must be a string of typecode characters or a list of data-types" self.lastcallargs = 0 def __call__(self, *args): @@ -835,10 +843,11 @@ class vectorize(object): else: self.nout = 1 theout = (theout,) - otypes = [] - for k in range(self.nout): - otypes.append(asarray(theout[k]).dtype.char) - self.otypes = ''.join(otypes) + if self.otypes == '': + otypes = [] + for k in range(self.nout): + otypes.append(asarray(theout[k]).dtype.char) + self.otypes = ''.join(otypes) if (self.ufunc is None): self.ufunc = frompyfunc(self.thefunc, nargs, self.nout) |