diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 459 |
1 files changed, 38 insertions, 421 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index abda31358..0bdc60b8c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -24,12 +24,10 @@ To generate user-defined SQL strings, see """ import re -from . import schema, sqltypes, operators, functions, \ - util as sql_util, visitors, elements, selectable, base +from . import schema, sqltypes, operators, functions, visitors, \ + elements, selectable, crud from .. import util, exc -import decimal import itertools -import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -64,17 +62,6 @@ BIND_TEMPLATES = { 'named': ":%(name)s" } -REQUIRED = util.symbol('REQUIRED', """ -Placeholder for the value within a :class:`.BindParameter` -which is required to be present when the statement is passed -to :meth:`.Connection.execute`. - -This symbol is typically used when a :func:`.expression.insert` -or :func:`.expression.update` statement is compiled without parameter -values present. - -""") - OPERATORS = { # binary @@ -725,7 +712,6 @@ class SQLCompiler(Compiled): for c in clauselist.clauses) if s) - def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: @@ -819,7 +805,8 @@ class SQLCompiler(Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None or cs._offset_clause is not None) and \ + text += (cs._limit_clause is not None + or cs._offset_clause is not None) and \ self.limit_clause(cs) or "" if self.ctes and \ @@ -876,15 +863,15 @@ class SQLCompiler(Compiled): isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator = binary.operator - disp = getattr(self, "visit_%s_binary" % operator.__name__, None) + operator_ = binary.operator + disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: - return disp(binary, operator, **kw) + return disp(binary, operator_, **kw) else: try: - opstring = OPERATORS[operator] + opstring = OPERATORS[operator_] except KeyError: - raise exc.UnsupportedCompilationError(self, operator) + raise exc.UnsupportedCompilationError(self, operator_) else: return self._generate_generic_binary(binary, opstring, **kw) @@ -966,7 +953,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -977,7 +964,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -988,7 +975,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -999,7 +986,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) @@ -1720,9 +1707,9 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt, **kw) + crud_params = crud._get_crud_params(self, insert_stmt, **kw) - if not colparams and \ + if not crud_params and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The '%s' dialect with current database " @@ -1737,9 +1724,9 @@ class SQLCompiler(Compiled): "version settings does not support " "in-place multirow inserts." % self.dialect.name) - colparams_single = colparams[0] + crud_params_single = crud_params[0] else: - colparams_single = colparams + crud_params_single = crud_params preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1770,9 +1757,9 @@ class SQLCompiler(Compiled): text += table_text - if colparams_single or not supports_default_values: + if crud_params_single or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams_single]) + for c in crud_params_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1784,20 +1771,20 @@ class SQLCompiler(Compiled): if insert_stmt.select is not None: text += " %s" % self.process(insert_stmt.select, **kw) - elif not colparams and supports_default_values: + elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( "(%s)" % ( - ', '.join(c[1] for c in colparam_set) + ', '.join(c[1] for c in crud_param_set) ) - for colparam_set in colparams + for crud_param_set in crud_params ) ) else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) + ', '.join([c[1] for c in crud_params]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1854,7 +1841,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, **kw) + crud_params = crud._get_crud_params(self, update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1881,7 +1868,7 @@ class SQLCompiler(Compiled): text += ', '.join( c[0]._compiler_dispatch(self, include_table=include_table) + - '=' + c[1] for c in colparams + '=' + c[1] for c in crud_params ) if self.returning or update_stmt._returning: @@ -1917,380 +1904,9 @@ class SQLCompiler(Compiled): return text - def _create_crud_bind_param(self, col, value, required=False, name=None): - if name is None: - name = col.key - bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) - bindparam._is_crud = True - return bindparam._compiler_dispatch(self) - @util.memoized_property def _key_getters_for_crud_column(self): - if self.isupdate and self.statement._extra_froms: - # when extra tables are present, refer to the columns - # in those extra tables as table-qualified, including in - # dictionaries and when rendering bind param names. - # the "main" table of the statement remains unqualified, - # allowing the most compatibility with a non-multi-table - # statement. - _et = set(self.statement._extra_froms) - - def _column_as_key(key): - str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: - return (key.table.name, str_key) - else: - return str_key - - def _getattr_col_key(col): - if col.table in _et: - return (col.table.name, col.key) - else: - return col.key - - def _col_bind_name(col): - if col.table in _et: - return "%s_%s" % (col.table.name, col.key) - else: - return col.key - - else: - _column_as_key = elements._column_as_key - _getattr_col_key = _col_bind_name = operator.attrgetter("key") - - return _column_as_key, _getattr_col_key, _col_bind_name - - def _get_colparams(self, stmt, **kw): - """create a set of tuples representing column/string pairs for use - in an INSERT or UPDATE statement. - - Also generates the Compiled object's postfetch, prefetch, and - returning column collections, used for default handling and ultimately - populating the ResultProxy's prefetch_cols() and postfetch_cols() - collections. - - """ - - self.postfetch = [] - self.prefetch = [] - self.returning = [] - - # no parameters in the statement, no parameters in the - # compiled params - return binds for all columns - if self.column_keys is None and stmt.parameters is None: - return [ - (c, self._create_crud_bind_param(c, - None, required=True)) - for c in stmt.table.columns - ] - - if stmt._has_multi_parameters: - stmt_parameters = stmt.parameters[0] - else: - stmt_parameters = stmt.parameters - - # getters - these are normally just column.key, - # but in the case of mysql multi-table update, the rules for - # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - self._key_getters_for_crud_column - - # if we have statement parameters - set defaults in the - # compiled params - if self.column_keys is None: - parameters = {} - else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in self.column_keys - if not stmt_parameters or - key not in stmt_parameters) - - # create a list of column assignment clauses as tuples - values = [] - - if stmt_parameters is not None: - for k, v in stmt_parameters.items(): - colkey = _column_as_key(k) - if colkey is not None: - parameters.setdefault(colkey, v) - else: - # a non-Column expression on the left side; - # add it to values() in an "as-is" state, - # coercing right side to bound param - if elements._is_literal(v): - v = self.process( - elements.BindParameter(None, v, type_=k.type), - **kw) - else: - v = self.process(v.self_group(), **kw) - - values.append((k, v)) - - need_pks = self.isinsert and \ - not self.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters - - implicit_returning = need_pks and \ - self.dialect.implicit_returning and \ - stmt.table.implicit_returning - - if self.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) - elif self.isupdate: - implicit_return_defaults = (self.dialect.implicit_returning and - stmt.table.implicit_returning and - stmt._return_defaults) - else: - implicit_return_defaults = False - - if implicit_return_defaults: - if stmt._return_defaults is True: - implicit_return_defaults = set(stmt.table.c) - else: - implicit_return_defaults = set(stmt._return_defaults) - - postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid - - check_columns = {} - - # special logic that only occurs for multi-table UPDATE - # statements - if self.isupdate and stmt._extra_froms and stmt_parameters: - normalized_params = dict( - (elements._clause_element_as_expr(c), param) - for c, param in stmt_parameters.items() - ) - affected_tables = set() - for t in stmt._extra_froms: - for c in t.c: - if c in normalized_params: - affected_tables.add(t) - check_columns[_getattr_col_key(c)] = c - value = normalized_params[c] - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c)) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - # determine tables which are actually - # to be updated - process onupdate and - # server_onupdate for these - for t in affected_tables: - for c in t.c: - if c in normalized_params: - continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), - **kw) - ) - ) - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param( - c, None, name=_col_bind_name(c) - ) - ) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - self.postfetch.append(c) - - if self.isinsert and stmt.select_names: - # for an insert from select, we can only use names that - # are given, so only select for those names. - cols = (stmt.table.c[_column_as_key(name)] - for name in stmt.select_names) - else: - # iterate through all table columns to maintain - # ordering, even for those cols that aren't included - cols = stmt.table.columns - - for c in cols: - col_key = _getattr_col_key(c) - if col_key in parameters and col_key not in check_columns: - value = parameters.pop(col_key) - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not stmt._has_multi_parameters - else "%s_0" % _col_bind_name(c) - ) - else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: - value = value._clone() - value.type = c.type - - if c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - - elif self.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): - - if implicit_returning: - if c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - self.returning.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - self.returning.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - else: - self.returning.append(c) - else: - if ( - (c.default is not None and - (not c.default.is_sequence or - self.dialect.supports_sequences)) or - c is stmt.table._autoincrement_column and - (self.dialect.supports_sequences or - self.dialect. - preexecute_autoincrement_sequences) - ): - - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - - self.prefetch.append(c) - - elif c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - # don't add primary key column to postfetch - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - elif self.isupdate: - if c.onupdate is not None and not c.onupdate.is_sequence: - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), **kw)) - ) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - if parameters and stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt.parameters - ).difference(check_columns) - if check: - raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) - ) - - if stmt._has_multi_parameters: - values_0 = values - values = [values] - - values.extend( - [ - ( - c, - (self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) if elements._is_literal(row[c.key]) - else self.process( - row[c.key].self_group(), **kw)) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) - ) - - return values + return crud._key_getters_for_crud_column(self) def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), @@ -2474,17 +2090,18 @@ class DDLCompiler(Compiled): constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) - return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints - if ( - constraint._create_rule is None or - constraint._create_rule(self)) - and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None - ) + return ", \n\t".join( + p for p in + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) |