summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index abf9e1090..1da73dee5 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,7 @@ import numpy as np
from .._utils import set_module
import numpy.core.numeric as _nx
from numpy.core.numeric import ScalarType, array
-from numpy.core.numerictypes import find_common_type, issubdtype
+from numpy.core.numerictypes import issubdtype
import numpy.matrixlib as matrixlib
from .function_base import diff
@@ -342,9 +342,8 @@ class AxisConcatenator:
axis = self.axis
objs = []
- scalars = []
- arraytypes = []
- scalartypes = []
+ # dtypes or scalars for weak scalar handling in result_type
+ result_type_objs = []
for k, item in enumerate(key):
scalar = False
@@ -390,10 +389,8 @@ class AxisConcatenator:
except (ValueError, TypeError) as e:
raise ValueError("unknown special directive") from e
elif type(item) in ScalarType:
- newobj = array(item, ndmin=ndmin)
- scalars.append(len(objs))
scalar = True
- scalartypes.append(newobj.dtype)
+ newobj = item
else:
item_ndim = np.ndim(item)
newobj = array(item, copy=False, subok=True, ndmin=ndmin)
@@ -405,15 +402,20 @@ class AxisConcatenator:
defaxes = list(range(ndmin))
axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2]
newobj = newobj.transpose(axes)
+
objs.append(newobj)
- if not scalar and isinstance(newobj, _nx.ndarray):
- arraytypes.append(newobj.dtype)
-
- # Ensure that scalars won't up-cast unless warranted
- final_dtype = find_common_type(arraytypes, scalartypes)
- if final_dtype is not None:
- for k in scalars:
- objs[k] = objs[k].astype(final_dtype)
+ if scalar:
+ result_type_objs.append(item)
+ else:
+ result_type_objs.append(newobj.dtype)
+
+ # Ensure that scalars won't up-cast unless warranted, for 0, drops
+ # through to error in concatenate.
+ if len(result_type_objs) != 0:
+ final_dtype = _nx.result_type(*result_type_objs)
+ # concatenate could do cast, but that can be overriden:
+ objs = [array(obj, copy=False, subok=True,
+ ndmin=ndmin, dtype=final_dtype) for obj in objs]
res = self.concatenate(tuple(objs), axis=axis)