From 7e1350d016684adbb13bd13ca8a62d0a716e71d0 Mon Sep 17 00:00:00 2001 From: Travis Oliphant Date: Fri, 11 Apr 2008 06:53:49 +0000 Subject: Fixed #728 scalar coercion problem with mixed types and r_ --- numpy/lib/index_tricks.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'numpy/lib/index_tricks.py') diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c45148057..22b8ef388 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,8 @@ __all__ = ['unravel_index', import sys import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType, array +from numpy.core.numeric import asarray, ScalarType, array, dtype +from numpy.core.numerictypes import find_common_type import math import function_base @@ -225,7 +226,8 @@ class AxisConcatenator(object): key = (key,) objs = [] scalars = [] - final_dtypedescr = None + arraytypes = [] + scalartypes = [] for k in range(len(key)): scalar = False if type(key[k]) is slice: @@ -272,6 +274,7 @@ class AxisConcatenator(object): newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True + scalartypes.append(newobj.dtype) else: newobj = key[k] if ndmin > 1: @@ -289,14 +292,15 @@ class AxisConcatenator(object): newobj = newobj.transpose(axes) del tempobj objs.append(newobj) - if isinstance(newobj, _nx.ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: + if not scalar and isinstance(newobj, _nx.ndarray): + arraytypes.append(newobj.dtype) + + # Esure that scalars won't up-cast unless warranted + final_dtype = find_common_type(arraytypes, scalartypes) + if final_dtype is not None: for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) + objs[k] = objs[k].astype(final_dtype) + res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) -- cgit v1.2.1