diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-04-26 10:34:46 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-04-26 15:40:18 -0400 |
commit | 6a0d61f12110624ad8709f67d4523e82bde262e5 (patch) | |
tree | ac1a7f60ab8e277224e0e1eedc01a4d6c3316b44 /lib/sqlalchemy | |
parent | 9f675fd042b05977f1b38887c2fbbb54ecd424f7 (diff) | |
download | sqlalchemy-6a0d61f12110624ad8709f67d4523e82bde262e5.tar.gz |
ensure correct cast for floats vs. numeric; other fixes
Fixed regression caused by the fix for :ticket:`9618` where floating point
values would lose precision being inserted in bulk, using either the
psycopg2 or psycopg drivers.
Implemented the :class:`_sqltypes.Double` type for SQL Server, having it
resolve to ``REAL``, or :class:`_mssql.REAL`, at DDL rendering time.
Fixed issue in Oracle dialects where ``Decimal`` returning types such as
:class:`_sqltypes.Numeric` would return floating point values, rather than
``Decimal`` objects, when these columns were used in the
:meth:`_dml.Insert.returning` clause to return INSERTed values.
Fixes: #9701
Change-Id: I8865496a6ccac6d44c19d0ca2a642b63c6172da9
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/cx_oracle.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/_psycopg_common.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_insert.py | 111 |
6 files changed, 150 insertions, 1 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index 0ef858a97..e6ad5e120 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -19,6 +19,7 @@ from .base import DATETIME from .base import DATETIME2 from .base import DATETIMEOFFSET from .base import DECIMAL +from .base import DOUBLE_PRECISION from .base import FLOAT from .base import IMAGE from .base import INTEGER diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4a7e48ab8..e66d01a35 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1156,7 +1156,7 @@ RESERVED_WORDS = { class REAL(sqltypes.REAL): - __visit_name__ = "REAL" + """the SQL Server REAL datatype.""" def __init__(self, **kw): # REAL is a synonym for FLOAT(24) on SQL server. @@ -1166,6 +1166,21 @@ class REAL(sqltypes.REAL): super().__init__(**kw) +class DOUBLE_PRECISION(sqltypes.DOUBLE_PRECISION): + """the SQL Server DOUBLE PRECISION datatype. + + .. versionadded:: 2.0.11 + + """ + + def __init__(self, **kw): + # DOUBLE PRECISION is a synonym for FLOAT(53) on SQL server. + # it is only accepted as the word "DOUBLE PRECISION" in DDL, + # the numeric precision value is not allowed to be present + kw.setdefault("precision", 53) + super().__init__(**kw) + + class TINYINT(sqltypes.Integer): __visit_name__ = "TINYINT" @@ -1670,6 +1685,7 @@ ischema_names = { "varbinary": VARBINARY, "bit": BIT, "real": REAL, + "double precision": DOUBLE_PRECISION, "image": IMAGE, "xml": XML, "timestamp": TIMESTAMP, @@ -1700,6 +1716,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return " ".join([c for c in (spec, collation) if c is not None]) + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type_, **kw) + def visit_FLOAT(self, type_, **kw): precision = getattr(type_, "precision", None) if precision is None: diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index f6f10c476..c0e308383 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -819,6 +819,15 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): outconverter=lambda value: value.read(), arraysize=len_params, ) + elif ( + isinstance(type_impl, _OracleNumeric) + and type_impl.asdecimal + ): + out_parameters[name] = self.cursor.var( + decimal.Decimal, + arraysize=len_params, + ) + else: out_parameters[name] = self.cursor.var( dbtype, arraysize=len_params diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index 739cbc5a9..b98518099 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -55,6 +55,10 @@ class _PsycopgNumeric(sqltypes.Numeric): ) +class _PsycopgFloat(_PsycopgNumeric): + __visit_name__ = "float" + + class _PsycopgHStore(HSTORE): def bind_processor(self, dialect): if dialect._has_native_hstore: @@ -104,6 +108,7 @@ class _PGDialect_common_psycopg(PGDialect): PGDialect.colspecs, { sqltypes.Numeric: _PsycopgNumeric, + sqltypes.Float: _PsycopgFloat, HSTORE: _PsycopgHStore, sqltypes.ARRAY: _PsycopgARRAY, INT2VECTOR: _PsycopgINT2VECTOR, diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 3332f7ce2..b59cce374 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1196,6 +1196,10 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property + def float_or_double_precision_behaves_generically(self): + return exclusions.closed() + + @property def precision_generic_float_type(self): """target backend will return native floating point numbers with at least seven decimal places when using the generic Float type. diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index ae54f6bcd..d49eb3284 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -1,13 +1,20 @@ # mypy: ignore-errors +from decimal import Decimal + +from . import testing from .. import fixtures from ..assertions import eq_ from ..config import requirements from ..schema import Column from ..schema import Table +from ... import Double +from ... import Float +from ... import Identity from ... import Integer from ... import literal from ... import literal_column +from ... import Numeric from ... import select from ... import String @@ -378,5 +385,109 @@ class ReturningTest(fixtures.TablesTest): eq_(rall, pks.all()) + @testing.combinations( + (Double(), 8.5514716, True), + ( + Double(53), + 8.5514716, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + (Float(), 8.5514, False), + ( + Float(8), + 8.5514, + True, + testing.requires.float_or_double_precision_behaves_generically, + ), + ( + Numeric(precision=15, scale=12, asdecimal=False), + 8.5514716, + True, + testing.requires.literal_float_coercion, + ), + ( + Numeric(precision=15, scale=12, asdecimal=True), + Decimal("8.5514716"), + False, + ), + argnames="type_,value,do_rounding", + ) + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("multiple_rows", [True, False]) + def test_insert_w_floats( + self, + connection, + metadata, + sort_by_parameter_order, + type_, + value, + do_rounding, + multiple_rows, + ): + """test #9701. + + this tests insertmanyvalues as well as decimal / floating point + RETURNING types + + """ + + t = Table( + "t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("value", type_), + ) + + t.create(connection) + + result = connection.execute( + t.insert().returning( + t.c.id, + t.c.value, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value}, + ) + + if multiple_rows: + i_range = range(1, 11) + else: + i_range = range(1, 2) + + # we want to test only that we are getting floating points back + # with some degree of the original value maintained, that it is not + # being truncated to an integer. there's too much variation in how + # drivers return floats, which should not be relied upon to be + # exact, for us to just compare as is (works for PG drivers but not + # others) so we use rounding here. There's precedent for this + # in suite/test_types.py::NumericTest as well + + if do_rounding: + eq_( + {(id_, round(val_, 5)) for id_, val_ in result}, + {(id_, round(value, 5)) for id_ in i_range}, + ) + + eq_( + { + round(val_, 5) + for val_ in connection.scalars(select(t.c.value)) + }, + {round(value, 5)}, + ) + else: + eq_( + set(result), + {(id_, value) for id_ in i_range}, + ) + + eq_( + set(connection.scalars(select(t.c.value))), + {value}, + ) + __all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") |