summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorScott Dugas <scott.dugas@foundationdb.com>2014-11-03 14:54:51 -0500
committerScott Dugas <scott.dugas@foundationdb.com>2014-11-03 14:54:51 -0500
commitb31ab006897d2709442f9745faf0cac6e0de1713 (patch)
treea6b428e9ca7f1f67c5193581ecd82a83632eeb79 /lib/sqlalchemy/sql/compiler.py
parentebb9d57cb385f49becbf54c6f78647715ddd1c29 (diff)
parent7bf5ac9c1e814c999d4930941935e1d5cfd236bf (diff)
downloadsqlalchemy-b31ab006897d2709442f9745faf0cac6e0de1713.tar.gz
Merge branch 'master' into fdbsql-tests
Conflicts: lib/sqlalchemy/testing/exclusions.py
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py660
1 files changed, 158 insertions, 502 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 72dd11eaf..5fa78ad0f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -24,12 +24,10 @@ To generate user-defined SQL strings, see
"""
import re
-from . import schema, sqltypes, operators, functions, \
- util as sql_util, visitors, elements, selectable, base
+from . import schema, sqltypes, operators, functions, visitors, \
+ elements, selectable, crud
from .. import util, exc
-import decimal
import itertools
-import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -64,17 +62,6 @@ BIND_TEMPLATES = {
'named': ":%(name)s"
}
-REQUIRED = util.symbol('REQUIRED', """
-Placeholder for the value within a :class:`.BindParameter`
-which is required to be present when the statement is passed
-to :meth:`.Connection.execute`.
-
-This symbol is typically used when a :func:`.expression.insert`
-or :func:`.expression.update` statement is compiled without parameter
-values present.
-
-""")
-
OPERATORS = {
# binary
@@ -503,7 +490,35 @@ class SQLCompiler(Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
- def visit_label_reference(self, element, **kwargs):
+ def visit_label_reference(
+ self, element, within_columns_clause=False, **kwargs):
+ if self.stack and self.dialect.supports_simple_order_by_label:
+ selectable = self.stack[-1]['selectable']
+
+ with_cols, only_froms = selectable._label_resolve_dict
+ if within_columns_clause:
+ resolve_dict = only_froms
+ else:
+ resolve_dict = with_cols
+
+ # this can be None in the case that a _label_reference()
+ # were subject to a replacement operation, in which case
+ # the replacement of the Label element may have changed
+ # to something else like a ColumnClause expression.
+ order_by_elem = element.element._order_by_label_element
+
+ if order_by_elem is not None and order_by_elem.name in \
+ resolve_dict:
+
+ kwargs['render_label_as_label'] = \
+ element.element._order_by_label_element
+
+ return self.process(
+ element.element, within_columns_clause=within_columns_clause,
+ **kwargs)
+
+ def visit_textual_label_reference(
+ self, element, within_columns_clause=False, **kwargs):
if not self.stack:
# compiling the element outside of the context of a SELECT
return self.process(
@@ -511,19 +526,25 @@ class SQLCompiler(Compiled):
)
selectable = self.stack[-1]['selectable']
+ with_cols, only_froms = selectable._label_resolve_dict
+
try:
- col = selectable._label_resolve_dict[element.text]
+ if within_columns_clause:
+ col = only_froms[element.element]
+ else:
+ col = with_cols[element.element]
except KeyError:
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.text))
+ util.ellipses_string(element.element))
return self.process(
element._text_clause
)
else:
kwargs['render_label_as_label'] = col
- return self.process(col, **kwargs)
+ return self.process(
+ col, within_columns_clause=within_columns_clause, **kwargs)
def visit_label(self, label,
add_to_result_map=None,
@@ -678,11 +699,7 @@ class SQLCompiler(Compiled):
else:
return "0"
- def visit_clauselist(self, clauselist, order_by_select=None, **kw):
- if order_by_select is not None:
- return self._order_by_clauselist(
- clauselist, order_by_select, **kw)
-
+ def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
sep = " "
@@ -695,27 +712,6 @@ class SQLCompiler(Compiled):
for c in clauselist.clauses)
if s)
- def _order_by_clauselist(self, clauselist, order_by_select, **kw):
- # look through raw columns collection for labels.
- # note that its OK we aren't expanding tables and other selectables
- # here; we can only add a label in the ORDER BY for an individual
- # label expression in the columns clause.
-
- raw_col = set(order_by_select._label_resolve_dict.keys())
-
- return ", ".join(
- s for s in
- (
- c._compiler_dispatch(
- self,
- render_label_as_label=c._order_by_label_element if
- c._order_by_label_element is not None and
- c._order_by_label_element._label in raw_col
- else None,
- **kw)
- for c in clauselist.clauses)
- if s)
-
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
@@ -750,6 +746,12 @@ class SQLCompiler(Compiled):
)
)
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ )
+
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
@@ -809,8 +811,9 @@ class SQLCompiler(Compiled):
text += " GROUP BY " + group_by
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None or cs._offset_clause is not None) and \
- self.limit_clause(cs) or ""
+ text += (cs._limit_clause is not None
+ or cs._offset_clause is not None) and \
+ self.limit_clause(cs, **kwargs) or ""
if self.ctes and \
compound_index == 0 and toplevel:
@@ -866,15 +869,15 @@ class SQLCompiler(Compiled):
isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
- operator = binary.operator
- disp = getattr(self, "visit_%s_binary" % operator.__name__, None)
+ operator_ = binary.operator
+ disp = getattr(self, "visit_%s_binary" % operator_.__name__, None)
if disp:
- return disp(binary, operator, **kw)
+ return disp(binary, operator_, **kw)
else:
try:
- opstring = OPERATORS[operator]
+ opstring = OPERATORS[operator_]
except KeyError:
- raise exc.UnsupportedCompilationError(self, operator)
+ raise exc.UnsupportedCompilationError(self, operator_)
else:
return self._generate_generic_binary(binary, opstring, **kw)
@@ -956,7 +959,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -967,7 +970,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -978,7 +981,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -989,7 +992,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
@@ -1321,6 +1324,9 @@ class SQLCompiler(Compiled):
def get_crud_hint_text(self, table, text):
return None
+ def get_statement_hint_text(self, hint_texts):
+ return " ".join(hint_texts)
+
def _transform_select_for_nested_joins(self, select):
"""Rewrite any "a JOIN (b JOIN c)" expression as
"a JOIN (select * from b JOIN c) AS anon", to support
@@ -1491,29 +1497,7 @@ class SQLCompiler(Compiled):
select, transformed_select)
return text
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
-
- if asfrom:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
- else:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
-
- new_correlate_froms = set(selectable._from_objects(*froms))
- all_correlate_froms = new_correlate_froms.union(correlate_froms)
-
- new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'iswrapper': iswrapper,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
- }
- self.stack.append(new_entry)
+ froms = self._setup_select_stack(select, entry, asfrom, iswrapper)
column_clause_args = kwargs.copy()
column_clause_args.update({
@@ -1524,18 +1508,11 @@ class SQLCompiler(Compiled):
text = "SELECT " # we're off to a good start !
if select._hints:
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
- hint_text = self.get_select_hint_text(byfrom)
+ hint_text, byfrom = self._setup_select_hints(select)
if hint_text:
text += hint_text + " "
+ else:
+ byfrom = None
if select._prefixes:
text += self._generate_prefixes(
@@ -1556,6 +1533,70 @@ class SQLCompiler(Compiled):
if c is not None
]
+ text = self._compose_select_body(
+ text, select, inner_columns, froms, byfrom, kwargs)
+
+ if select._statement_hints:
+ per_dialect = [
+ ht for (dialect_name, ht)
+ in select._statement_hints
+ if dialect_name in ('*', self.dialect.name)
+ ]
+ if per_dialect:
+ text += " " + self.get_statement_hint_text(per_dialect)
+
+ if self.ctes and \
+ compound_index == 0 and toplevel:
+ text = self._render_cte_clause() + text
+
+ self.stack.pop(-1)
+
+ if asfrom and parens:
+ return "(" + text + ")"
+ else:
+ return text
+
+ def _setup_select_hints(self, select):
+ byfrom = dict([
+ (from_, hinttext % {
+ 'name': from_._compiler_dispatch(
+ self, ashint=True)
+ })
+ for (from_, dialect), hinttext in
+ select._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ hint_text = self.get_select_hint_text(byfrom)
+ return hint_text, byfrom
+
+ def _setup_select_stack(self, select, entry, asfrom, iswrapper):
+ correlate_froms = entry['correlate_froms']
+ asfrom_froms = entry['asfrom_froms']
+
+ if asfrom:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms),
+ implicit_correlate_froms=())
+ else:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
+
+ new_correlate_froms = set(selectable._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
+
+ new_entry = {
+ 'asfrom_froms': new_correlate_froms,
+ 'iswrapper': iswrapper,
+ 'correlate_froms': all_correlate_froms,
+ 'selectable': select,
+ }
+ self.stack.append(new_entry)
+ return froms
+
+ def _compose_select_body(
+ self, text, select, inner_columns, froms, byfrom, kwargs):
text += ', '.join(inner_columns)
if froms:
@@ -1590,13 +1631,7 @@ class SQLCompiler(Compiled):
text += " \nHAVING " + t
if select._order_by_clause.clauses:
- if self.dialect.supports_simple_order_by_label:
- order_by_select = select
- else:
- order_by_select = None
-
- text += self.order_by_clause(
- select, order_by_select=order_by_select, **kwargs)
+ text += self.order_by_clause(select, **kwargs)
if (select._limit_clause is not None or
select._offset_clause is not None):
@@ -1605,16 +1640,7 @@ class SQLCompiler(Compiled):
if select._for_update_arg is not None:
text += self.for_update_clause(select, **kwargs)
- if self.ctes and \
- compound_index == 0 and toplevel:
- text = self._render_cte_clause() + text
-
- self.stack.pop(-1)
-
- if asfrom and parens:
- return "(" + text + ")"
- else:
- return text
+ return text
def _generate_prefixes(self, stmt, prefixes, **kw):
clause = " ".join(
@@ -1704,9 +1730,9 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, **kw)
+ crud_params = crud._get_crud_params(self, insert_stmt, **kw)
- if not colparams and \
+ if not crud_params and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The '%s' dialect with current database "
@@ -1721,9 +1747,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." %
self.dialect.name)
- colparams_single = colparams[0]
+ crud_params_single = crud_params[0]
else:
- colparams_single = colparams
+ crud_params_single = crud_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1754,9 +1780,9 @@ class SQLCompiler(Compiled):
text += table_text
- if colparams_single or not supports_default_values:
+ if crud_params_single or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in colparams_single])
+ for c in crud_params_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1767,21 +1793,21 @@ class SQLCompiler(Compiled):
text += " " + returning_clause
if insert_stmt.select is not None:
- text += " %s" % self.process(insert_stmt.select, **kw)
- elif not colparams and supports_default_values:
+ text += " %s" % self.process(self._insert_from_select, **kw)
+ elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (
- ', '.join(c[1] for c in colparam_set)
+ ', '.join(c[1] for c in crud_param_set)
)
- for colparam_set in colparams
+ for crud_param_set in crud_params
)
)
else:
text += " VALUES (%s)" % \
- ', '.join([c[1] for c in colparams])
+ ', '.join([c[1] for c in crud_params])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1838,7 +1864,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, **kw)
+ crud_params = crud._get_crud_params(self, update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1865,7 +1891,7 @@ class SQLCompiler(Compiled):
text += ', '.join(
c[0]._compiler_dispatch(self,
include_table=include_table) +
- '=' + c[1] for c in colparams
+ '=' + c[1] for c in crud_params
)
if self.returning or update_stmt._returning:
@@ -1901,380 +1927,9 @@ class SQLCompiler(Compiled):
return text
- def _create_crud_bind_param(self, col, value, required=False, name=None):
- if name is None:
- name = col.key
- bindparam = elements.BindParameter(name, value,
- type_=col.type, required=required)
- bindparam._is_crud = True
- return bindparam._compiler_dispatch(self)
-
@util.memoized_property
def _key_getters_for_crud_column(self):
- if self.isupdate and self.statement._extra_froms:
- # when extra tables are present, refer to the columns
- # in those extra tables as table-qualified, including in
- # dictionaries and when rendering bind param names.
- # the "main" table of the statement remains unqualified,
- # allowing the most compatibility with a non-multi-table
- # statement.
- _et = set(self.statement._extra_froms)
-
- def _column_as_key(key):
- str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
- return (key.table.name, str_key)
- else:
- return str_key
-
- def _getattr_col_key(col):
- if col.table in _et:
- return (col.table.name, col.key)
- else:
- return col.key
-
- def _col_bind_name(col):
- if col.table in _et:
- return "%s_%s" % (col.table.name, col.key)
- else:
- return col.key
-
- else:
- _column_as_key = elements._column_as_key
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
-
- return _column_as_key, _getattr_col_key, _col_bind_name
-
- def _get_colparams(self, stmt, **kw):
- """create a set of tuples representing column/string pairs for use
- in an INSERT or UPDATE statement.
-
- Also generates the Compiled object's postfetch, prefetch, and
- returning column collections, used for default handling and ultimately
- populating the ResultProxy's prefetch_cols() and postfetch_cols()
- collections.
-
- """
-
- self.postfetch = []
- self.prefetch = []
- self.returning = []
-
- # no parameters in the statement, no parameters in the
- # compiled params - return binds for all columns
- if self.column_keys is None and stmt.parameters is None:
- return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
- for c in stmt.table.columns
- ]
-
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
- else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- self._key_getters_for_crud_column
-
- # if we have statement parameters - set defaults in the
- # compiled params
- if self.column_keys is None:
- parameters = {}
- else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in self.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
-
- # create a list of column assignment clauses as tuples
- values = []
-
- if stmt_parameters is not None:
- for k, v in stmt_parameters.items():
- colkey = _column_as_key(k)
- if colkey is not None:
- parameters.setdefault(colkey, v)
- else:
- # a non-Column expression on the left side;
- # add it to values() in an "as-is" state,
- # coercing right side to bound param
- if elements._is_literal(v):
- v = self.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
- else:
- v = self.process(v.self_group(), **kw)
-
- values.append((k, v))
-
- need_pks = self.isinsert and \
- not self.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
-
- implicit_returning = need_pks and \
- self.dialect.implicit_returning and \
- stmt.table.implicit_returning
-
- if self.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
- elif self.isupdate:
- implicit_return_defaults = (self.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
- else:
- implicit_return_defaults = False
-
- if implicit_return_defaults:
- if stmt._return_defaults is True:
- implicit_return_defaults = set(stmt.table.c)
- else:
- implicit_return_defaults = set(stmt._return_defaults)
-
- postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
-
- check_columns = {}
-
- # special logic that only occurs for multi-table UPDATE
- # statements
- if self.isupdate and stmt._extra_froms and stmt_parameters:
- normalized_params = dict(
- (elements._clause_element_as_expr(c), param)
- for c, param in stmt_parameters.items()
- )
- affected_tables = set()
- for t in stmt._extra_froms:
- for c in t.c:
- if c in normalized_params:
- affected_tables.add(t)
- check_columns[_getattr_col_key(c)] = c
- value = normalized_params[c]
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
- # determine tables which are actually
- # to be updated - process onupdate and
- # server_onupdate for these
- for t in affected_tables:
- for c in t.c:
- if c in normalized_params:
- continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
- )
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(
- c, None, name=_col_bind_name(c)
- )
- )
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- self.postfetch.append(c)
-
- if self.isinsert and stmt.select_names:
- # for an insert from select, we can only use names that
- # are given, so only select for those names.
- cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
- else:
- # iterate through all table columns to maintain
- # ordering, even for those cols that aren't included
- cols = stmt.table.columns
-
- for c in cols:
- col_key = _getattr_col_key(c)
- if col_key in parameters and col_key not in check_columns:
- value = parameters.pop(col_key)
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c)
- if not stmt._has_multi_parameters
- else "%s_0" % _col_bind_name(c)
- )
- else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
- value = value._clone()
- value.type = c.type
-
- if c.primary_key and implicit_returning:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
-
- elif self.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
-
- if implicit_returning:
- if c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- self.returning.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
- self.returning.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- else:
- self.returning.append(c)
- else:
- if (
- (c.default is not None and
- (not c.default.is_sequence or
- self.dialect.supports_sequences)) or
- c is stmt.table._autoincrement_column and
- (self.dialect.supports_sequences or
- self.dialect.
- preexecute_autoincrement_sequences)
- ):
-
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
-
- self.prefetch.append(c)
-
- elif c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
-
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- # don't add primary key column to postfetch
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- elif self.isupdate:
- if c.onupdate is not None and not c.onupdate.is_sequence:
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(), **kw))
- )
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt.parameters
- ).difference(check_columns)
- if check:
- raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
- )
-
- if stmt._has_multi_parameters:
- values_0 = values
- values = [values]
-
- values.extend(
- [
- (
- c,
- (self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- ) if elements._is_literal(row[c.key])
- else self.process(
- row[c.key].self_group(), **kw))
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
- )
-
- return values
+ return crud._key_getters_for_crud_column(self)
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
@@ -2458,17 +2113,18 @@ class DDLCompiler(Compiled):
constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
- return ", \n\t".join(p for p in
- (self.process(constraint)
- for constraint in constraints
- if (
- constraint._create_rule is None or
- constraint._create_rule(self))
- and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
- )
+ return ", \n\t".join(
+ p for p in
+ (self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None or
+ constraint._create_rule(self))
+ and (
+ not self.dialect.supports_alter or
+ not getattr(constraint, 'use_alter', False)
+ )) if p is not None
+ )
def visit_drop_table(self, drop):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)