diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2019-11-09 16:21:55 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2019-11-09 16:21:55 +0000 |
commit | ed2c5f9ad1f92010e447797576ab4eef3beee21f (patch) | |
tree | 6e39e7366de4ac2fbb6a98e4e5938e33f34422a4 | |
parent | 4a2dd4902a1168234f14bdd0634728086d53c406 (diff) | |
parent | 3a0e0531c179e598c345e5be24e350c375ce7e22 (diff) | |
download | sqlalchemy-ed2c5f9ad1f92010e447797576ab4eef3beee21f.tar.gz |
Merge "Support for generated columns"
26 files changed, 903 insertions, 15 deletions
diff --git a/doc/build/changelog/unreleased_13/4894.rst b/doc/build/changelog/unreleased_13/4894.rst new file mode 100644 index 000000000..ee0f6f812 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4894.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: usecase, schema + :tickets: 4894 + + Added DDL support for "computed columns"; these are DDL column + specifications for columns that have a server-computed value, either upon + SELECT (known as "virtual") or at the point of which they are INSERTed or + UPDATEd (known as "stored"). Support is established for Postgresql, MySQL, + Oracle SQL Server and Firebird. Thanks to Federico Caselli for lots of work + on this one. + + .. seealso:: + + :ref:`computed_ddl` + diff --git a/doc/build/core/defaults.rst b/doc/build/core/defaults.rst index f0d724078..7a52f6177 100644 --- a/doc/build/core/defaults.rst +++ b/doc/build/core/defaults.rst @@ -536,9 +536,105 @@ including the default schema, if any. :ref:`oracle_returning` - in the Oracle dialect documentation +.. _computed_ddl: + +Computed (GENERATED ALWAYS AS) Columns +-------------------------------------- + +.. versionadded:: 1.3.11 + +The :class:`.Computed` construct allows a :class:`.Column` to be declared in +DDL as a "GENERATED ALWAYS AS" column, that is, one which has a value that is +computed by the database server. The construct accepts a SQL expression +typically declared textually using a string or the :func:`.text` construct, in +a similar manner as that of :class:`.CheckConstraint`. The SQL expression is +then interpreted by the database server in order to determine the value for the +column within a row. + +Example:: + + from sqlalchemy import Table, Column, MetaData, Integer, Computed + + metadata = MetaData() + + square = Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + +The DDL for the ``square`` table when run on a PostgreSQL 12 backend will look +like:: + + CREATE TABLE square ( + id SERIAL NOT NULL, + side INTEGER, + area INTEGER GENERATED ALWAYS AS (side * side) STORED, + perimeter INTEGER GENERATED ALWAYS AS (4 * side) STORED, + PRIMARY KEY (id) + ) + +Whether the value is persisted upon INSERT and UPDATE, or if it is calculated +on fetch, is an implementation detail of the database; the former is known as +"stored" and the latter is known as "virtual". Some database implementations +support both, but some only support one or the other. The optional +:paramref:`.Computed.persisted` flag may be specified as ``True`` or ``False`` +to indicate if the "STORED" or "VIRTUAL" keyword should be rendered in DDL, +however this will raise an error if the keyword is not supported by the target +backend; leaving it unset will use a working default for the target backend. + +The :class:`.Computed` construct is a subclass of the :class:`.FetchedValue` +object, and will set itself up as both the "server default" and "server +onupdate" generator for the target :class:`.Column`, meaning it will be treated +as a default generating column when INSERT and UPDATE statements are generated, +as well as that it will be fetched as a generating column when using the ORM. +This includes that it will be part of the RETURNING clause of the database +for databases which support RETURNING and the generated values are to be +eagerly fetched. + +.. note:: A :class:`.Column` that is defined with the :class:`.Computed` + construct may not store any value outside of that which the server applies + to it; SQLAlchemy's behavior when a value is passed for such a column + to be written in INSERT or UPDATE is currently that the value will be + ignored. + +"GENERATED ALWAYS AS" is currently known to be supported by: + +* MySQL version 5.7 and onwards + +* MariaDB 10.x series and onwards + +* PostgreSQL as of version 12 + +* Oracle - with the caveat that RETURNING does not work correctly with UPDATE + (a warning will be emitted to this effect when the UPDATE..RETURNING that + includes a computed column is rendered) + +* Microsoft SQL Server + +* Firebird + +When :class:`.Computed` is used with an unsupported backend, if the target +dialect does not support it, a :class:`.CompileError` is raised when attempting +to render the construct. Otherwise, if the dialect supports it but the +particular database server version in use does not, then a subclass of +:class:`.DBAPIError`, usually :class:`.OperationalError`, is raised when the +DDL is emitted to the database. + +.. seealso:: + + :class:`.Computed` + Default Objects API ------------------- +.. autoclass:: Computed + :members: + + .. autoclass:: ColumnDefault diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index a1e166a0d..92128ed22 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -17,6 +17,7 @@ from .schema import DefaultClause # noqa from .schema import FetchedValue # noqa from .schema import ForeignKey # noqa from .schema import ForeignKeyConstraint # noqa +from .schema import Computed # noqa from .schema import Index # noqa from .schema import MetaData # noqa from .schema import PrimaryKeyConstraint # noqa diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index c7c921cb4..052ecf05b 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -581,6 +581,17 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): drop.element ) + def visit_computed_column(self, generated): + if generated.persisted is not None: + raise exc.CompileError( + "Firebird computed columns do not support a persistence " + "method setting; set the 'persisted' flag to None for " + "Firebird support." + ) + return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): """Install Firebird specific reserved words.""" diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 6c7b732ce..5d4de4a33 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1942,13 +1942,15 @@ class MSSQLStrictCompiler(MSSQLCompiler): class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = ( - self.preparer.format_column(column) - + " " - + self.dialect.type_compiler.process( + colspec = self.preparer.format_column(column) + + # type is not accepted in a computed column + if column.computed is not None: + colspec += " " + self.process(column.computed) + else: + colspec += " " + self.dialect.type_compiler.process( column.type, type_expression=column ) - ) if column.nullable is not None: if ( @@ -1958,7 +1960,8 @@ class MSDDLCompiler(compiler.DDLCompiler): or column.autoincrement is True ): colspec += " NOT NULL" - else: + elif column.computed is None: + # don't specify "NULL" for computed columns colspec += " NULL" if column.table is None: @@ -2110,6 +2113,15 @@ class MSDDLCompiler(compiler.DDLCompiler): text += self.define_constraint_deferrability(constraint) return text + def visit_computed_column(self, generated): + text = "AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + # explicitly check for True|False since None means server default + if generated.persisted is True: + text += " PERSISTED" + return text + class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 73484aea1..05edb6310 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1503,6 +1503,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ), ] + if column.computed is not None: + colspec.append(self.process(column.computed)) + is_timestamp = isinstance( column.type._unwrapped_dialect_impl(self.dialect), sqltypes.TIMESTAMP, diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 4c5a717b9..c1e91fb12 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -446,6 +446,7 @@ columns for non-unique indexes, all but the last column for unique indexes). from itertools import groupby import re +from ... import Computed from ... import exc from ... import schema as sa_schema from ... import sql @@ -905,6 +906,16 @@ class OracleCompiler(compiler.SQLCompiler): for i, column in enumerate( expression._select_iterables(returning_cols) ): + if self.isupdate and isinstance(column.server_default, Computed): + util.warn( + "Computed columns don't work with Oracle UPDATE " + "statements that use RETURNING; the value of the column " + "*before* the UPDATE takes place is returned. It is " + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." + ) if column.type._has_column_expression: col_expr = column.type.column_expression(column) else: @@ -1186,6 +1197,19 @@ class OracleDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) + def visit_computed_column(self, generated): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + raise exc.CompileError( + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." + ) + elif generated.persisted is False: + text += " VIRTUAL" + return text + class OracleIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index e94f9913c..d6fd2623b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1873,6 +1873,9 @@ class PGDDLCompiler(compiler.DDLCompiler): if default is not None: colspec += " DEFAULT " + default + if column.computed is not None: + colspec += " " + self.process(column.computed) + if not column.nullable: colspec += " NOT NULL" return colspec @@ -2043,6 +2046,18 @@ class PGDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) + def visit_computed_column(self, generated): + if generated.persisted is False: + raise exc.CompileError( + "PostrgreSQL computed columns do not support 'virtual' " + "persistence; set the 'persisted' flag to None or True for " + "PostgreSQL support." + ) + + return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TSVECTOR(self, type_, **kw): diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c1f108a0d..02d44a260 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1032,6 +1032,9 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): + if column.computed is not None: + raise exc.CompileError("SQLite does not support computed columns") + coltype = self.dialect.type_compiler.process( column.type, type_expression=column ) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index f50acdb1c..6adeded36 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -47,6 +47,7 @@ from .sql.schema import DefaultGenerator # noqa from .sql.schema import FetchedValue # noqa from .sql.schema import ForeignKey # noqa from .sql.schema import ForeignKeyConstraint # noqa +from .sql.schema import Computed # noqa from .sql.schema import Index # noqa from .sql.schema import MetaData # noqa from .sql.schema import PrimaryKeyConstraint # noqa diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 546fffc6c..85c1750b7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3173,6 +3173,9 @@ class DDLCompiler(Compiled): if default is not None: colspec += " DEFAULT " + default + if column.computed is not None: + colspec += " " + self.process(column.computed) + if not column.nullable: colspec += " NOT NULL" return colspec @@ -3314,6 +3317,16 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text + def visit_computed_column(self, generated): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + text += " STORED" + elif generated.persisted is False: + text += " VIRTUAL" + return text + class GenericTypeCompiler(TypeCompiler): def visit_FLOAT(self, type_, **kw): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ee7dc61ce..8c325538c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1028,9 +1028,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): :class:`.SchemaItem` derived constructs which will be applied as options to the column. These include instances of :class:`.Constraint`, :class:`.ForeignKey`, :class:`.ColumnDefault`, - and :class:`.Sequence`. In some cases an equivalent keyword - argument is available such as ``server_default``, ``default`` - and ``unique``. + :class:`.Sequence`, :class:`.Computed`. In some cases an + equivalent keyword argument is available such as ``server_default``, + ``default`` and ``unique``. :param autoincrement: Set up "auto increment" semantics for an integer primary key column. The default value is the string ``"auto"`` @@ -1296,6 +1296,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): self.constraints = set() self.foreign_keys = set() self.comment = kwargs.pop("comment", None) + self.computed = None # check if this Column is proxying another column if "_proxies" in kwargs: @@ -1502,6 +1503,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): c.copy(**kw) for c in self.constraints if not c._type_bound ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + server_default = self.server_default + server_onupdate = self.server_onupdate + if isinstance(server_default, Computed): + server_default = server_onupdate = None + args.append(self.server_default.copy(**kw)) + type_ = self.type if isinstance(type_, SchemaEventTarget): type_ = type_.copy(**kw) @@ -1518,9 +1525,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): index=self.index, autoincrement=self.autoincrement, default=self.default, - server_default=self.server_default, + server_default=server_default, onupdate=self.onupdate, - server_onupdate=self.server_onupdate, + server_onupdate=server_onupdate, doc=self.doc, comment=self.comment, *args @@ -4348,3 +4355,89 @@ class _SchemaTranslateMap(object): _default_schema_map = _SchemaTranslateMap(None) _schema_getter = _SchemaTranslateMap._schema_getter + + +class Computed(FetchedValue, SchemaItem): + """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax. + + The :class:`.Computed` construct is an inline construct added to the + argument list of a :class:`.Column` object:: + + from sqlalchemy import Computed + + Table('square', meta, + Column('side', Float, nullable=False), + Column('area', Float, Computed('side * side')) + ) + + See the linked documentation below for complete details. + + .. versionadded:: 1.3.11 + + .. seealso:: + + :ref:`computed_ddl` + + """ + + __visit_name__ = "computed_column" + + @_document_text_coercion( + "sqltext", ":class:`.Computed`", ":paramref:`.Computed.sqltext`" + ) + def __init__(self, sqltext, persisted=None): + """Construct a GENERATED ALWAYS AS DDL construct to accompany a + :class:`.Column`. + + :param sqltext: + A string containing the column generation expression, which will be + used verbatim, or a SQL expression construct, such as a :func:`.text` + object. If given as a string, the object is converted to a + :func:`.text` object. + + :param persisted: + Optional, controls how this column should be persisted by the + database. Possible values are: + + * None, the default, it will use the default persistence defined + by the database. + * True, will render ``GENERATED ALWAYS AS ... STORED``, or the + equivalent for the target database if supported + * False, will render ``GENERATED ALWAYS AS ... VIRTUAL``, or the + equivalent for the target database if supported. + + Specifying ``True`` or ``False`` may raise an error when the DDL + is emitted to the target database if the databse does not support + that persistence option. Leaving this parameter at its default + of ``None`` is guaranteed to succeed for all databases that support + ``GENERATED ALWAYS AS``. + + """ + self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext) + self.persisted = persisted + self.column = None + + def _set_parent(self, parent): + if not isinstance( + parent.server_default, (type(None), Computed) + ) or not isinstance(parent.server_onupdate, (type(None), Computed)): + raise exc.ArgumentError( + "A generated column cannot specify a server_default or a " + "server_onupdate argument" + ) + self.column = parent + parent.computed = self + self.column.server_onupdate = self + self.column.server_default = self + + def _as_for_update(self, for_update): + return self + + def copy(self, target_table=None, **kw): + if target_table is not None: + sqltext = _copy_expression(self.sqltext, self.table, target_table) + else: + sqltext = self.sqltext + g = Computed(sqltext, persisted=self.persisted) + + return self._schema_item_copy(g) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 87bbc6a0f..8262142ec 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -44,7 +44,7 @@ def combinations(*comb, **kw): well as if it is included in the tokens used to create the id of the parameter set. - If omitted, the argment combinations are passed to parametrize as is. If + If omitted, the argument combinations are passed to parametrize as is. If passed, each argument combination is turned into a pytest.param() object, mapping the elements of the argument tuple to produce an id based on a character value in the same position within the string template using the @@ -59,9 +59,12 @@ def combinations(*comb, **kw): r - the given argument should be passed and it should be added to the id by calling repr() - s- the given argument should be passed and it should be added to the + s - the given argument should be passed and it should be added to the id by calling str() + a - (argument) the given argument should be passed and it should not + be used to generated the id + e.g.:: @testing.combinations( diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index c45156d6b..fd8d82690 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1064,3 +1064,7 @@ class SuiteRequirements(Requirements): return True except ImportError: return False + + @property + def computed_columns(self): + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 02cdcf4f5..9db2daf7a 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -9,6 +9,7 @@ from ..schema import Column from ..schema import Table from ... import bindparam from ... import case +from ... import Computed from ... import false from ... import func from ... import Integer @@ -858,3 +859,47 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls): + with config.db.begin() as conn: + conn.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select([text("*")]) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + [self.tables.square.c.area, self.tables.square.c.perimeter] + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 00a8a08fc..6b3244c30 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -1,5 +1,6 @@ # -*- encoding: utf-8 from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import extract from sqlalchemy import func @@ -1193,6 +1194,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT TRY_CAST (t1.id AS INTEGER) AS id FROM t1", ) + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " PERSISTED", True), + ("persisted_false", "", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER NULL, y AS (x + 2)%s)" % text, + ) + class SchemaTest(fixtures.TestBase): def setup(self): diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 301562d1c..d59c0549f 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -8,6 +8,7 @@ from sqlalchemy import CHAR from sqlalchemy import CheckConstraint from sqlalchemy import CLOB from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import DATE from sqlalchemy import Date from sqlalchemy import DATETIME @@ -386,6 +387,28 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile(sql.delete(a1), "DELETE FROM t1 AS a1") + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " STORED", True), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + class SQLTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 32ef7fd78..f73eb96c0 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1,6 +1,8 @@ # coding: utf-8 from sqlalchemy import and_ from sqlalchemy import bindparam +from sqlalchemy import Computed +from sqlalchemy import exc from sqlalchemy import except_ from sqlalchemy import ForeignKey from sqlalchemy import func @@ -27,6 +29,7 @@ from sqlalchemy.engine import default from sqlalchemy.sql import column from sqlalchemy.sql import quoted_name from sqlalchemy.sql import table +from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -1088,6 +1091,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "t1.c2, t1.c3 INTO :ret_0, :ret_1", ) + def test_returning_insert_computed(self): + m = MetaData() + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + self.assert_compile( + t1.insert().values(id=1, foo=5).returning(t1.c.bar), + "INSERT INTO t1 (id, foo) VALUES (:id, :foo) " + "RETURNING t1.bar INTO :ret_0", + ) + + def test_returning_update_computed_warning(self): + m = MetaData() + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): + self.assert_compile( + t1.update().values(id=1, foo=5).returning(t1.c.bar), + "UPDATE t1 SET id=:id, foo=:foo RETURNING t1.bar INTO :ret_0", + ) + def test_compound(self): t1 = table("t1", column("c1"), column("c2"), column("c3")) t2 = table("t2", column("c1"), column("c2"), column("c3")) @@ -1186,6 +1223,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE BITMAP INDEX idx3 ON testtbl (data)", ) + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_column_computed_persisted_true(self): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=True)), + ) + assert_raises_message( + exc.CompileError, + r".*Oracle computed columns do not support 'stored' ", + schema.CreateTable(t).compile, + dialect=oracle.dialect(), + ) + class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): def test_basic(self): diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 62926700e..20c6336b8 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -3,6 +3,7 @@ import re from sqlalchemy import bindparam +from sqlalchemy import Computed from sqlalchemy import create_engine from sqlalchemy import exc from sqlalchemy import Float @@ -258,6 +259,72 @@ class EncodingErrorsTest(fixtures.TestBase): ) +class ComputedReturningTest(fixtures.TablesTest): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + Table( + "test_no_returning", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + implicit_returning=False, + ) + + def test_computed_insert(self): + test = self.tables.test + with testing.db.connect() as conn: + result = conn.execute( + test.insert().return_defaults(), {"id": 1, "foo": 5} + ) + + eq_(result.returned_defaults, (47,)) + + eq_(conn.scalar(select([test.c.bar])), 47) + + def test_computed_update_warning(self): + test = self.tables.test + with testing.db.connect() as conn: + conn.execute(test.insert(), {"id": 1, "foo": 5}) + + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + + # returns the *old* value + eq_(result.returned_defaults, (47,)) + + eq_(conn.scalar(select([test.c.bar])), 52) + + def test_computed_update_no_warning(self): + test = self.tables.test_no_returning + with testing.db.connect() as conn: + conn.execute(test.insert(), {"id": 1, "foo": 5}) + + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + + # no returning + eq_(result.returned_defaults, None) + + eq_(conn.scalar(select([test.c.bar])), 52) + + class OutParamTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "oracle+cx_oracle" __backend__ = True diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 83e3ee3fd..4c4c43281 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3,6 +3,7 @@ from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import Enum from sqlalchemy import exc @@ -1541,6 +1542,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): q, "DELETE FROM t1 AS a1 USING t2 WHERE a1.c1 = t2.c1" ) + @testing.combinations( + ("no_persisted", " STORED", "ignore"), + ("persisted_none", " STORED", None), + ("persisted_true", " STORED", True), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_column_computed_persisted_false(self): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=False)), + ) + assert_raises_message( + exc.CompileError, + "PostrgreSQL computed columns do not support 'virtual'", + schema.CreateTable(t).compile, + dialect=postgresql.dialect(), + ) + class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index f1ce321e3..ad146bc77 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -1,4 +1,5 @@ from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func @@ -438,6 +439,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(column("_somecol"), '"_somecol"') self.assert_compile(column("$somecol"), '"$somecol"') + @testing.combinations( + ("no_persisted", "ignore"), ("persisted_none", None), id_="ia" + ) + def test_column_computed(self, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2))", + ) + + @testing.combinations( + ("persisted_true", True), ("persisted_false", False), id_="ia" + ) + def test_column_computed_raises(self, persisted): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=persisted)), + ) + assert_raises_message( + exc.CompileError, + "Firebird computed columns do not support a persistence method", + schema.CreateTable(t).compile, + dialect=firebird.dialect(), + ) + class TypesTest(fixtures.TestBase): __only_on__ = "firebird" diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 7d9c75175..638d84334 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -10,6 +10,7 @@ from sqlalchemy import bindparam from sqlalchemy import CheckConstraint from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import Computed from sqlalchemy import create_engine from sqlalchemy import DefaultClause from sqlalchemy import event @@ -756,6 +757,29 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): url = make_url(url) eq_(d.create_connect_args(url), expected) + @testing.combinations( + ("no_persisted", "ignore"), + ("persisted_none", None), + ("persisted_true", True), + ("persisted_false", False), + id_="ia", + ) + def test_column_computed(self, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + assert_raises_message( + exc.CompileError, + "SQLite does not support computed columns", + schema.CreateTable(t).compile, + dialect=sqlite.dialect(), + ) + class AttachedDBTest(fixtures.TestBase): __only_on__ = "sqlite" diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 5fbe091be..9f37dbf4d 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -1,11 +1,16 @@ import sqlalchemy as sa +from sqlalchemy import Computed from sqlalchemy import event from sqlalchemy import Integer from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper +from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertsql import assert_engine +from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -170,3 +175,173 @@ class ExcludedDefaultsTest(fixtures.MappedTest): sess.add(f1) sess.flush() eq_(dt.select().execute().fetchall(), [(1, "hello")]) + + +class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): + """test that computed columns are recognized as server + oninsert/onupdate defaults.""" + + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + @classmethod + def setup_classes(cls): + class Thing(cls.Basic): + pass + + class ThingNoEager(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Thing = cls.classes.Thing + + mapper(Thing, cls.tables.test, eager_defaults=True) + + ThingNoEager = cls.classes.ThingNoEager + mapper(ThingNoEager, cls.tables.test, eager_defaults=False) + + @testing.combinations(("eager", True), ("noneager", False), id_="ia") + def test_insert_computed(self, eager): + if eager: + Thing = self.classes.Thing + else: + Thing = self.classes.ThingNoEager + + s = Session() + + t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10)) + + s.add_all([t1, t2]) + + with assert_engine(testing.db) as asserter: + s.flush() + eq_(t1.bar, 5 + 42) + eq_(t2.bar, 10 + 42) + + if eager and testing.db.dialect.implicit_returning: + asserter.assert_( + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, %(foo)s) " + "RETURNING test.bar", + [{"foo": 5, "id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, %(foo)s) " + "RETURNING test.bar", + [{"foo": 10, "id": 2}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, :foo)", + [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) + + @testing.requires.computed_columns_on_update_returning + def test_update_computed_eager(self): + self._test_update_computed(True) + + def test_update_computed_noneager(self): + self._test_update_computed(False) + + def _test_update_computed(self, eager): + if eager: + Thing = self.classes.Thing + else: + Thing = self.classes.ThingNoEager + + s = Session() + + t1, t2 = (Thing(id=1, foo=1), Thing(id=2, foo=2)) + + s.add_all([t1, t2]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + + with assert_engine(testing.db) as asserter: + s.flush() + eq_(t1.bar, 5 + 42) + eq_(t2.bar, 6 + 42) + + if eager and testing.db.dialect.implicit_returning: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=%(foo)s " + "WHERE test.id = %(test_id)s " + "RETURNING test.bar", + [{"foo": 5, "test_id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "UPDATE test SET foo=%(foo)s " + "WHERE test.id = %(test_id)s " + "RETURNING test.bar", + [{"foo": 6, "test_id": 2}], + dialect="postgresql", + ), + ) + elif eager: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 5, "test_id": 1}], + ), + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 6, "test_id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 5, "test_id": 1}, {"foo": 6, "test_id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) diff --git a/test/requirements.py b/test/requirements.py index 42a24c3a9..fd093f270 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -373,6 +373,10 @@ class DefaultRequirements(SuiteRequirements): return only_if([self.returning, self.sqlite]) @property + def computed_columns_on_update_returning(self): + return self.computed_columns + skip_if("oracle") + + @property def correlated_outer_joins(self): """Target must support an outer join to a subquery which correlates to the parent.""" @@ -774,8 +778,9 @@ class DefaultRequirements(SuiteRequirements): @property def nullsordering(self): """Target backends that support nulls ordering.""" - return fails_on_everything_except("postgresql", "oracle", "firebird", - "sqlite >= 3.30.0") + return fails_on_everything_except( + "postgresql", "oracle", "firebird", "sqlite >= 3.30.0" + ) @property def reflects_pk_names(self): @@ -1439,3 +1444,7 @@ class DefaultRequirements(SuiteRequirements): lambda config: against(config, "oracle+cx_oracle") and config.db.dialect.cx_oracle_ver < (6,) ) + + @property + def computed_columns(self): + return skip_if(["postgresql < 12", "sqlite", "mysql < 5.7"]) diff --git a/test/sql/test_computed.py b/test/sql/test_computed.py new file mode 100644 index 000000000..2999c621c --- /dev/null +++ b/test/sql/test_computed.py @@ -0,0 +1,80 @@ +# coding: utf-8 +from sqlalchemy import Column +from sqlalchemy import Computed +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table +from sqlalchemy.exc import ArgumentError +from sqlalchemy.schema import CreateTable +from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import combinations +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not_ + + +class DDLComputedTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + @combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " STORED", True), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_server_default_onupdate(self): + text = ( + "A generated column cannot specify a server_default or a " + "server_onupdate argument" + ) + + def fn(**kwargs): + m = MetaData() + Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2"), **kwargs), + ) + + assert_raises_message(ArgumentError, text, fn, server_default="42") + assert_raises_message(ArgumentError, text, fn, server_onupdate="42") + + def test_tometadata(self): + comp1 = Computed("x + 2") + m = MetaData() + t = Table("t", m, Column("x", Integer), Column("y", Integer, comp1)) + is_(comp1.column, t.c.y) + is_(t.c.y.server_onupdate, comp1) + is_(t.c.y.server_default, comp1) + + m2 = MetaData() + t2 = t.tometadata(m2) + comp2 = t2.c.y.server_default + + is_not_(comp1, comp2) + + is_(comp1.column, t.c.y) + is_(t.c.y.server_onupdate, comp1) + is_(t.c.y.server_default, comp1) + + is_(comp2.column, t2.c.y) + is_(t2.c.y.server_onupdate, comp2) + is_(t2.c.y.server_default, comp2) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index e08c35bfb..05e5ec3c2 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -651,6 +651,8 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): class ToMetaDataTest(fixtures.TestBase, ComparesTables): @testing.requires.check_constraints def test_copy(self): + # TODO: modernize this test + from sqlalchemy.testing.schema import Table meta = MetaData() |