summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2022-11-07 12:08:10 +0100
committerSebastian Berg <sebastianb@nvidia.com>2022-11-07 12:08:10 +0100
commitb70a1e995faaf01cbf89cfe1316a7c14570cdbec (patch)
tree581efbb9b638a8d4e1b5ce614ba0c8e65460efc8 /numpy/lib/index_tricks.py
parentfcafb6560e37c948a594dce36d300888148bc599 (diff)
downloadnumpy-b70a1e995faaf01cbf89cfe1316a7c14570cdbec.tar.gz
MAINT: Refactor AxisConcatenator to not use find_common_type
Rather, use `result_type` instead. There are some exceedingly small theoretical changes, since `result_type` currently uses value-inspection logic. `find_common_type` did not, because it pre-dates the value inspection logic. (I.e. in theory, this switches it to value-based promotion, just to partially undo that in NEP 50; although more changes there.) The only place where it is fathomable to matter is if someone is using `np.c_[uint8_arr, -1]` to append 255 to an unsigned integer array.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py31
1 files changed, 16 insertions, 15 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 4f414925d..4c407d643 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,7 @@ import numpy.core.numeric as _nx
from numpy.core.numeric import (
asarray, ScalarType, array, alltrue, cumprod, arange, ndim
)
-from numpy.core.numerictypes import find_common_type, issubdtype
+from numpy.core.numerictypes import issubdtype
import numpy.matrixlib as matrixlib
from .function_base import diff
@@ -339,9 +339,8 @@ class AxisConcatenator:
axis = self.axis
objs = []
- scalars = []
- arraytypes = []
- scalartypes = []
+ # dtypes or scalars for weak scalar handling in result_type
+ result_type_objs = []
for k, item in enumerate(key):
scalar = False
@@ -387,10 +386,8 @@ class AxisConcatenator:
except (ValueError, TypeError) as e:
raise ValueError("unknown special directive") from e
elif type(item) in ScalarType:
- newobj = array(item, ndmin=ndmin)
- scalars.append(len(objs))
scalar = True
- scalartypes.append(newobj.dtype)
+ newobj = array(item, ndmin=ndmin)
else:
item_ndim = ndim(item)
newobj = array(item, copy=False, subok=True, ndmin=ndmin)
@@ -402,15 +399,19 @@ class AxisConcatenator:
defaxes = list(range(ndmin))
axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2]
newobj = newobj.transpose(axes)
+
objs.append(newobj)
- if not scalar and isinstance(newobj, _nx.ndarray):
- arraytypes.append(newobj.dtype)
-
- # Ensure that scalars won't up-cast unless warranted
- final_dtype = find_common_type(arraytypes, scalartypes)
- if final_dtype is not None:
- for k in scalars:
- objs[k] = objs[k].astype(final_dtype)
+ if scalar:
+ result_type_objs.append(item)
+ else:
+ result_type_objs.append(newobj.dtype)
+
+ # Ensure that scalars won't up-cast unless warranted, for 0, drops
+ # through to error in concatenate.
+ if len(result_type_objs) != 0:
+ final_dtype = _nx.result_type(*result_type_objs)
+ # concatenate could do cast, but that can be overriden:
+ objs = [obj.astype(final_dtype, copy=False) for obj in objs]
res = self.concatenate(tuple(objs), axis=axis)