diff options
author | Travis Oliphant <oliphant@enthought.com> | 2008-04-11 06:53:49 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2008-04-11 06:53:49 +0000 |
commit | 7e1350d016684adbb13bd13ca8a62d0a716e71d0 (patch) | |
tree | 2045108f450fcee29ece6594165cd2e697c69ea2 /numpy/lib/index_tricks.py | |
parent | cb7f01c247002a6921d9a247d60bacf2db3cd39b (diff) | |
download | numpy-7e1350d016684adbb13bd13ca8a62d0a716e71d0.tar.gz |
Fixed #728 scalar coercion problem with mixed types and r_
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c45148057..22b8ef388 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,8 @@ __all__ = ['unravel_index', import sys import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType, array +from numpy.core.numeric import asarray, ScalarType, array, dtype +from numpy.core.numerictypes import find_common_type import math import function_base @@ -225,7 +226,8 @@ class AxisConcatenator(object): key = (key,) objs = [] scalars = [] - final_dtypedescr = None + arraytypes = [] + scalartypes = [] for k in range(len(key)): scalar = False if type(key[k]) is slice: @@ -272,6 +274,7 @@ class AxisConcatenator(object): newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True + scalartypes.append(newobj.dtype) else: newobj = key[k] if ndmin > 1: @@ -289,14 +292,15 @@ class AxisConcatenator(object): newobj = newobj.transpose(axes) del tempobj objs.append(newobj) - if isinstance(newobj, _nx.ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: + if not scalar and isinstance(newobj, _nx.ndarray): + arraytypes.append(newobj.dtype) + + # Esure that scalars won't up-cast unless warranted + final_dtype = find_common_type(arraytypes, scalartypes) + if final_dtype is not None: for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) + objs[k] = objs[k].astype(final_dtype) + res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) |