summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2008-04-11 06:53:49 +0000
committerTravis Oliphant <oliphant@enthought.com>2008-04-11 06:53:49 +0000
commit7e1350d016684adbb13bd13ca8a62d0a716e71d0 (patch)
tree2045108f450fcee29ece6594165cd2e697c69ea2 /numpy/lib/index_tricks.py
parentcb7f01c247002a6921d9a247d60bacf2db3cd39b (diff)
downloadnumpy-7e1350d016684adbb13bd13ca8a62d0a716e71d0.tar.gz
Fixed #728 scalar coercion problem with mixed types and r_
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c45148057..22b8ef388 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,8 @@ __all__ = ['unravel_index',
import sys
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType, array
+from numpy.core.numeric import asarray, ScalarType, array, dtype
+from numpy.core.numerictypes import find_common_type
import math
import function_base
@@ -225,7 +226,8 @@ class AxisConcatenator(object):
key = (key,)
objs = []
scalars = []
- final_dtypedescr = None
+ arraytypes = []
+ scalartypes = []
for k in range(len(key)):
scalar = False
if type(key[k]) is slice:
@@ -272,6 +274,7 @@ class AxisConcatenator(object):
newobj = array(key[k],ndmin=ndmin)
scalars.append(k)
scalar = True
+ scalartypes.append(newobj.dtype)
else:
newobj = key[k]
if ndmin > 1:
@@ -289,14 +292,15 @@ class AxisConcatenator(object):
newobj = newobj.transpose(axes)
del tempobj
objs.append(newobj)
- if isinstance(newobj, _nx.ndarray) and not scalar:
- if final_dtypedescr is None:
- final_dtypedescr = newobj.dtype
- elif newobj.dtype > final_dtypedescr:
- final_dtypedescr = newobj.dtype
- if final_dtypedescr is not None:
+ if not scalar and isinstance(newobj, _nx.ndarray):
+ arraytypes.append(newobj.dtype)
+
+ # Esure that scalars won't up-cast unless warranted
+ final_dtype = find_common_type(arraytypes, scalartypes)
+ if final_dtype is not None:
for k in scalars:
- objs[k] = objs[k].astype(final_dtypedescr)
+ objs[k] = objs[k].astype(final_dtype)
+
res = _nx.concatenate(tuple(objs),axis=self.axis)
return self._retval(res)