diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c380df0d8..b09538e42 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -172,7 +172,10 @@ class concatenator(object): if type(key) is not types.TupleType: key = (key,) objs = [] + scalars = [] + final_dtypedescr = None for k in range(len(key)): + scalar = False if type(key[k]) is types.SliceType: step = key[k].step start = key[k].start @@ -197,9 +200,19 @@ class concatenator(object): raise ValueError, "Unknown special directive." elif type(key[k]) in ScalarType: newobj = asarray([key[k]]) + scalars.append(k) + scalar = True else: newobj = key[k] objs.append(newobj) + if isinstance(newobj, _nx.ndarray) and not scalar: + if final_dtypedescr is None: + final_dtypedescr = newobj.dtypedescr + elif newobj.dtypedescr > final_dtypedescr: + final_dtypedescr = newobj.dtypedescr + if final_dtypedescr is not None: + for k in scalars: + objs[k] = objs[k].astype(final_dtypedescr) res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) |