summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-11-07 18:40:03 -0500
committermike bayer <mike_mp@zzzcomputing.com>2022-11-11 16:20:00 +0000
commit8e91cfe529b9b0150c16e52e22e4590bfbbe79fd (patch)
treedc8328ae669164a8fe7cf9c8a821ba92a9057921
parente3a8d198917f4246365e09fa975d55c64082cd2e (diff)
downloadsqlalchemy-8e91cfe529b9b0150c16e52e22e4590bfbbe79fd.tar.gz
establish consistency for RETURNING column labels
The RETURNING clause now renders columns using the routine as that of the :class:`.Select` to generate labels, which will include disambiguating labels, as well as that a SQL function surrounding a named column will be labeled using the column name itself. This is a more comprehensive change than a similar one made for the 1.4 series that adjusted the function label issue only. includes 1.4's changelog for the backported version which also fixes an Oracle issue independently of the 2.0 series. Fixes: #8770 Change-Id: I2ab078a214a778ffe1720dbd864ae4c105a0691d
-rw-r--r--doc/build/changelog/unreleased_14/8770.rst23
-rw-r--r--doc/build/changelog/unreleased_20/8770.rst10
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py19
-rw-r--r--lib/sqlalchemy/sql/compiler.py23
-rw-r--r--lib/sqlalchemy/sql/dml.py4
-rw-r--r--lib/sqlalchemy/sql/selectable.py8
-rw-r--r--test/dialect/mssql/test_compiler.py29
-rw-r--r--test/dialect/oracle/test_compiler.py31
-rw-r--r--test/dialect/postgresql/test_compiler.py30
-rw-r--r--test/sql/test_labels.py97
-rw-r--r--test/sql/test_returning.py44
11 files changed, 310 insertions, 8 deletions
diff --git a/doc/build/changelog/unreleased_14/8770.rst b/doc/build/changelog/unreleased_14/8770.rst
new file mode 100644
index 000000000..8968b0361
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8770.rst
@@ -0,0 +1,23 @@
+.. change::
+ :tags: bug, postgresql, mssql
+ :tickets: 8770
+
+ For the PostgreSQL and SQL Server dialects only, adjusted the compiler so
+ that when rendering column expressions in the RETURNING clause, the "non
+ anon" label that's used in SELECT statements is suggested for SQL
+ expression elements that generate a label; the primary example is a SQL
+ function that may be emitting as part of the column's type, where the label
+ name should match the column's name by default. This restores a not-well
+ defined behavior that had changed in version 1.4.21 due to :ticket:`6718`,
+ :ticket:`6710`. The Oracle dialect has a different RETURNING implementation
+ and was not affected by this issue. Version 2.0 features an across the
+ board change for its widely expanded support of RETURNING on other
+ backends.
+
+
+.. change::
+ :tags: bug, oracle
+
+ Fixed issue in the Oracle dialect where an INSERT statement that used
+ ``insert(some_table).values(...).returning(some_table)`` against a full
+ :class:`.Table` object at once would fail to execute, raising an exception.
diff --git a/doc/build/changelog/unreleased_20/8770.rst b/doc/build/changelog/unreleased_20/8770.rst
new file mode 100644
index 000000000..59b94d658
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8770.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 8770
+
+ The RETURNING clause now renders columns using the routine as that of the
+ :class:`.Select` to generate labels, which will include disambiguating
+ labels, as well as that a SQL function surrounding a named column will be
+ labeled using the column name itself. This is a more comprehensive change
+ than a similar one made for the 1.4 series that adjusted the function label
+ issue only.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index a338ba27a..53fe96c9a 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2295,11 +2295,24 @@ class MSSQLCompiler(compiler.SQLCompiler):
columns = [
self._label_returning_column(
stmt,
- adapter.traverse(c),
+ adapter.traverse(column),
populate_result_map,
- {"result_map_targets": (c,)},
+ {"result_map_targets": (column,)},
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ name=name,
+ proxy_name=proxy_name,
+ **kw,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in stmt._generate_columns_plus_names(
+ True, cols=expression._select_iterables(returning_cols)
)
- for c in expression._select_iterables(returning_cols)
]
return "OUTPUT " + ", ".join(columns)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 3e62cb350..97397e9cf 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -3760,7 +3760,6 @@ class SQLCompiler(Compiled):
"_label_select_column is only relevant within "
"the columns clause of a SELECT or RETURNING"
)
-
if isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
@@ -4416,9 +4415,27 @@ class SQLCompiler(Compiled):
populate_result_map: bool,
**kw: Any,
) -> str:
+
columns = [
- self._label_returning_column(stmt, c, populate_result_map, **kw)
- for c in base._select_iterables(returning_cols)
+ self._label_returning_column(
+ stmt,
+ column,
+ populate_result_map,
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ name=name,
+ proxy_name=proxy_name,
+ **kw,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in stmt._generate_columns_plus_names(
+ True, cols=base._select_iterables(returning_cols)
+ )
]
return "RETURNING " + ", ".join(columns)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 5145a4a16..2d3e3598b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -59,6 +59,7 @@ from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import Join
+from .selectable import SelectLabelStyle
from .selectable import TableClause
from .selectable import TypedReturnsRows
from .sqltypes import NullType
@@ -399,6 +400,9 @@ class UpdateBase(
] = util.EMPTY_DICT
named_with_column = False
+ _label_style: SelectLabelStyle = (
+ SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY
+ )
table: _DMLTableElement
_return_defaults = False
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9de015774..488dfe721 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2193,7 +2193,9 @@ class SelectsRows(ReturnsRows):
_label_style: SelectLabelStyle = LABEL_STYLE_NONE
def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
+ self,
+ anon_for_dupe_key: bool,
+ cols: Optional[_SelectIterable] = None,
) -> List[_ColumnsPlusNames]:
"""Generate column names as rendered in a SELECT statement by
the compiler.
@@ -2204,7 +2206,9 @@ class SelectsRows(ReturnsRows):
_column_naming_convention as well.
"""
- cols = self._all_selected_columns
+
+ if cols is None:
+ cols = self._all_selected_columns
key_naming_convention = SelectState._column_naming_convention(
self._label_style
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 8605ea9c0..b575595ac 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -36,6 +36,7 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.assertions import eq_ignore_whitespace
+from sqlalchemy.types import TypeEngine
tbl = table("t", column("a"))
@@ -119,6 +120,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"Latin1_General_CS_AS_KS_WS_CI ASC",
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) OUTPUT inserted.name, "
+ "lower(inserted.value) AS value VALUES (:name, :value)",
+ )
+
def test_join_with_hint(self):
t1 = table(
"t1",
diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py
index 2973c6e39..8981e74e8 100644
--- a/test/dialect/oracle/test_compiler.py
+++ b/test/dialect/oracle/test_compiler.py
@@ -9,6 +9,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import Index
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import literal
from sqlalchemy import literal_column
@@ -42,6 +43,7 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import eq_ignore_whitespace
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -1359,6 +1361,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"t1.c2, t1.c3 INTO :ret_0, :ret_1",
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) "
+ "INTO :ret_0, :ret_1",
+ )
+
def test_returning_insert_computed(self):
m = MetaData()
t1 = Table(
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 96a8e7d5a..338d0da4e 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -61,6 +61,7 @@ from sqlalchemy.testing.assertions import AssertsCompiledSQL
from sqlalchemy.testing.assertions import eq_ignore_whitespace
from sqlalchemy.testing.assertions import expect_warnings
from sqlalchemy.testing.assertions import is_
+from sqlalchemy.types import TypeEngine
from sqlalchemy.util import OrderedDict
@@ -200,6 +201,35 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
dialect=dialect,
)
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) "
+ "VALUES (%(name)s, %(value)s) RETURNING some_table.name, "
+ "lower(some_table.value) AS value",
+ )
+
def test_create_drop_enum(self):
# test escaping and unicode within CREATE TYPE for ENUM
typ = postgresql.ENUM("val1", "val2", "val's 3", "méil", name="myname")
diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py
index 42d9c5f00..a74c5811c 100644
--- a/test/sql/test_labels.py
+++ b/test/sql/test_labels.py
@@ -2,6 +2,8 @@ from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import cast
from sqlalchemy import exc as exceptions
+from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import MetaData
@@ -32,6 +34,7 @@ from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import TypeEngine
IDENT_LENGTH = 29
@@ -827,6 +830,100 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
return SomeColThing
+ @testing.fixture
+ def compiler_column_fixture(self):
+ return self._fixture()
+
+ @testing.fixture
+ def column_expression_fixture(self):
+ class MyString(TypeEngine):
+ def column_expression(self, column):
+ return func.lower(column)
+
+ return table(
+ "some_table", column("name", String), column("value", MyString)
+ )
+
+ def test_plain_select_compiler_expression(self, compiler_column_fixture):
+ expr = compiler_column_fixture
+ table1 = self.table1
+
+ self.assert_compile(
+ select(
+ table1.c.name,
+ expr(table1.c.value),
+ ),
+ "SELECT some_table.name, SOME_COL_THING(some_table.value) "
+ "AS value FROM some_table",
+ )
+
+ def test_plain_select_column_expression(self, column_expression_fixture):
+ table1 = column_expression_fixture
+
+ self.assert_compile(
+ select(table1),
+ "SELECT some_table.name, lower(some_table.value) AS value "
+ "FROM some_table",
+ )
+
+ def test_plain_returning_compiler_expression(
+ self, compiler_column_fixture
+ ):
+ expr = compiler_column_fixture
+ table1 = self.table1
+
+ self.assert_compile(
+ insert(table1).returning(
+ table1.c.name,
+ expr(table1.c.value),
+ ),
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, "
+ "SOME_COL_THING(some_table.value) AS value",
+ )
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns
+ ):
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = insert(table1).returning(table1)
+ else:
+ stmt = insert(table1).returning(table1.c.name, table1.c.value)
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) AS value",
+ )
+
+ def test_select_dupes_column_expression(self, column_expression_fixture):
+ table1 = column_expression_fixture
+
+ self.assert_compile(
+ select(table1.c.name, table1.c.value, table1.c.value),
+ "SELECT some_table.name, lower(some_table.value) AS value, "
+ "lower(some_table.value) AS value__1 FROM some_table",
+ )
+
+ def test_returning_dupes_column_expression(
+ self, column_expression_fixture
+ ):
+ table1 = column_expression_fixture
+
+ stmt = insert(table1).returning(
+ table1.c.name, table1.c.value, table1.c.value
+ )
+
+ self.assert_compile(
+ stmt,
+ "INSERT INTO some_table (name, value) VALUES (:name, :value) "
+ "RETURNING some_table.name, lower(some_table.value) AS value, "
+ "lower(some_table.value) AS value__1",
+ )
+
def test_column_auto_label_dupes_label_style_none(self):
expr = self._fixture()
table1 = self.table1
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 32d4c7740..e0299e334 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -415,6 +415,50 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
result = connection.execute(ins)
eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
+ @testing.fixture
+ def column_expression_fixture(self, metadata, connection):
+ class MyString(TypeDecorator):
+ cache_ok = True
+ impl = String(50)
+
+ def column_expression(self, column):
+ return func.lower(column)
+
+ t1 = Table(
+ "some_table",
+ metadata,
+ Column("name", String(50)),
+ Column("value", MyString(50)),
+ )
+ metadata.create_all(connection)
+ return t1
+
+ @testing.combinations("columns", "table", argnames="use_columns")
+ def test_plain_returning_column_expression(
+ self, column_expression_fixture, use_columns, connection
+ ):
+ """test #8770"""
+ table1 = column_expression_fixture
+
+ if use_columns == "columns":
+ stmt = (
+ insert(table1)
+ .values(name="n1", value="ValUE1")
+ .returning(table1)
+ )
+ else:
+ stmt = (
+ insert(table1)
+ .values(name="n1", value="ValUE1")
+ .returning(table1.c.name, table1.c.value)
+ )
+
+ result = connection.execute(stmt)
+ row = result.first()
+
+ eq_(row._mapping["name"], "n1")
+ eq_(row._mapping["value"], "value1")
+
@testing.fails_on_everything_except(
"postgresql", "mariadb>=10.5", "sqlite>=3.34"
)