summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c45148057..22b8ef388 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,8 @@ __all__ = ['unravel_index',
import sys
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType, array
+from numpy.core.numeric import asarray, ScalarType, array, dtype
+from numpy.core.numerictypes import find_common_type
import math
import function_base
@@ -225,7 +226,8 @@ class AxisConcatenator(object):
key = (key,)
objs = []
scalars = []
- final_dtypedescr = None
+ arraytypes = []
+ scalartypes = []
for k in range(len(key)):
scalar = False
if type(key[k]) is slice:
@@ -272,6 +274,7 @@ class AxisConcatenator(object):
newobj = array(key[k],ndmin=ndmin)
scalars.append(k)
scalar = True
+ scalartypes.append(newobj.dtype)
else:
newobj = key[k]
if ndmin > 1:
@@ -289,14 +292,15 @@ class AxisConcatenator(object):
newobj = newobj.transpose(axes)
del tempobj
objs.append(newobj)
- if isinstance(newobj, _nx.ndarray) and not scalar:
- if final_dtypedescr is None:
- final_dtypedescr = newobj.dtype
- elif newobj.dtype > final_dtypedescr:
- final_dtypedescr = newobj.dtype
- if final_dtypedescr is not None:
+ if not scalar and isinstance(newobj, _nx.ndarray):
+ arraytypes.append(newobj.dtype)
+
+ # Esure that scalars won't up-cast unless warranted
+ final_dtype = find_common_type(arraytypes, scalartypes)
+ if final_dtype is not None:
for k in scalars:
- objs[k] = objs[k].astype(final_dtypedescr)
+ objs[k] = objs[k].astype(final_dtype)
+
res = _nx.concatenate(tuple(objs),axis=self.axis)
return self._retval(res)