diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index abf9e1090..1da73dee5 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,7 @@ import numpy as np from .._utils import set_module import numpy.core.numeric as _nx from numpy.core.numeric import ScalarType, array -from numpy.core.numerictypes import find_common_type, issubdtype +from numpy.core.numerictypes import issubdtype import numpy.matrixlib as matrixlib from .function_base import diff @@ -342,9 +342,8 @@ class AxisConcatenator: axis = self.axis objs = [] - scalars = [] - arraytypes = [] - scalartypes = [] + # dtypes or scalars for weak scalar handling in result_type + result_type_objs = [] for k, item in enumerate(key): scalar = False @@ -390,10 +389,8 @@ class AxisConcatenator: except (ValueError, TypeError) as e: raise ValueError("unknown special directive") from e elif type(item) in ScalarType: - newobj = array(item, ndmin=ndmin) - scalars.append(len(objs)) scalar = True - scalartypes.append(newobj.dtype) + newobj = item else: item_ndim = np.ndim(item) newobj = array(item, copy=False, subok=True, ndmin=ndmin) @@ -405,15 +402,20 @@ class AxisConcatenator: defaxes = list(range(ndmin)) axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2] newobj = newobj.transpose(axes) + objs.append(newobj) - if not scalar and isinstance(newobj, _nx.ndarray): - arraytypes.append(newobj.dtype) - - # Ensure that scalars won't up-cast unless warranted - final_dtype = find_common_type(arraytypes, scalartypes) - if final_dtype is not None: - for k in scalars: - objs[k] = objs[k].astype(final_dtype) + if scalar: + result_type_objs.append(item) + else: + result_type_objs.append(newobj.dtype) + + # Ensure that scalars won't up-cast unless warranted, for 0, drops + # through to error in concatenate. + if len(result_type_objs) != 0: + final_dtype = _nx.result_type(*result_type_objs) + # concatenate could do cast, but that can be overriden: + objs = [array(obj, copy=False, subok=True, + ndmin=ndmin, dtype=final_dtype) for obj in objs] res = self.concatenate(tuple(objs), axis=axis) |