diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c45148057..22b8ef388 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,8 @@ __all__ = ['unravel_index', import sys import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType, array +from numpy.core.numeric import asarray, ScalarType, array, dtype +from numpy.core.numerictypes import find_common_type import math import function_base @@ -225,7 +226,8 @@ class AxisConcatenator(object): key = (key,) objs = [] scalars = [] - final_dtypedescr = None + arraytypes = [] + scalartypes = [] for k in range(len(key)): scalar = False if type(key[k]) is slice: @@ -272,6 +274,7 @@ class AxisConcatenator(object): newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True + scalartypes.append(newobj.dtype) else: newobj = key[k] if ndmin > 1: @@ -289,14 +292,15 @@ class AxisConcatenator(object): newobj = newobj.transpose(axes) del tempobj objs.append(newobj) - if isinstance(newobj, _nx.ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: + if not scalar and isinstance(newobj, _nx.ndarray): + arraytypes.append(newobj.dtype) + + # Esure that scalars won't up-cast unless warranted + final_dtype = find_common_type(arraytypes, scalartypes) + if final_dtype is not None: for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) + objs[k] = objs[k].astype(final_dtype) + res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) |