diff options
Diffstat (limited to 'Lib/decimal.py')
-rw-r--r-- | Lib/decimal.py | 162 |
1 files changed, 107 insertions, 55 deletions
diff --git a/Lib/decimal.py b/Lib/decimal.py index 5207a479ca..55faf993be 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -718,6 +718,39 @@ class Decimal(_numbers.Real, _numbers.Inexact): return other._fix_nan(context) return 0 + def _compare_check_nans(self, other, context): + """Version of _check_nans used for the signaling comparisons + compare_signal, __le__, __lt__, __ge__, __gt__. + + Signal InvalidOperation if either self or other is a (quiet + or signaling) NaN. Signaling NaNs take precedence over quiet + NaNs. + + Return 0 if neither operand is a NaN. + + """ + if context is None: + context = getcontext() + + if self._is_special or other._is_special: + if self.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + self) + elif other.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + other) + elif self.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + self) + elif other.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + other) + return 0 + def __bool__(self): """Return True if self is nonzero; otherwise return False. @@ -725,18 +758,13 @@ class Decimal(_numbers.Real, _numbers.Inexact): """ return self._is_special or self._int != '0' - def __cmp__(self, other): - other = _convert_other(other) - if other is NotImplemented: - # Never return NotImplemented - return 1 + def _cmp(self, other): + """Compare the two non-NaN decimal instances self and other. - if self._is_special or other._is_special: - # check for nans, without raising on a signaling nan - if self._isnan() or other._isnan(): - return 1 # Comparison involving NaN's always reports self > other + Returns -1 if self < other, 0 if self == other and 1 + if self > other. This routine is for internal use only.""" - # INF = INF + if self._is_special or other._is_special: return cmp(self._isinfinity(), other._isinfinity()) # check for zeros; note that cmp(0, -0) should return 0 @@ -765,35 +793,72 @@ class Decimal(_numbers.Real, _numbers.Inexact): else: # self_adjusted < other_adjusted return -((-1)**self._sign) + # Note: The Decimal standard doesn't cover rich comparisons for + # Decimals. In particular, the specification is silent on the + # subject of what should happen for a comparison involving a NaN. + # We take the following approach: + # + # == comparisons involving a NaN always return False + # != comparisons involving a NaN always return True + # <, >, <= and >= comparisons involving a (quiet or signaling) + # NaN signal InvalidOperation, and return False if the + # InvalidOperation is trapped. + # + # This behavior is designed to conform as closely as possible to + # that specified by IEEE 754. + def __eq__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) == 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return False + return self._cmp(other) == 0 def __ne__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) != 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return True + return self._cmp(other) != 0 - def __lt__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) < 0 - def __le__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) <= 0 + def __lt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) < 0 - def __gt__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) > 0 + def __le__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) <= 0 - def __ge__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) >= 0 + def __gt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) > 0 + + def __ge__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) >= 0 def compare(self, other, context=None): """Compares one to another. @@ -812,7 +877,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): if ans: return ans - return Decimal(self.__cmp__(other)) + return Decimal(self._cmp(other)) def __hash__(self): """x.__hash__() <==> hash(x)""" @@ -2463,7 +2528,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: # If both operands are finite and equal in numerical value # then an ordering is applied: @@ -2505,7 +2570,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: c = self.compare_total(other) @@ -2553,23 +2618,10 @@ class Decimal(_numbers.Real, _numbers.Inexact): It's pretty much like compare(), but all NaNs signal, with signaling NaNs taking precedence over quiet NaNs. """ - if context is None: - context = getcontext() - - self_is_nan = self._isnan() - other_is_nan = other._isnan() - if self_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - self) - if other_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - other) - if self_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - self) - if other_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - other) + other = _convert_other(other, raiseit = True) + ans = self._compare_check_nans(other, context) + if ans: + return ans return self.compare(other, context=context) def compare_total(self, other): @@ -3076,7 +3128,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3106,7 +3158,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3181,7 +3233,7 @@ class Decimal(_numbers.Real, _numbers.Inexact): if ans: return ans - comparison = self.__cmp__(other) + comparison = self._cmp(other) if comparison == 0: return self.copy_sign(other) |