diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 158 |
1 files changed, 134 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fa7eeaecf..8df93a60b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -36,6 +36,7 @@ from . import roles from . import schema from . import selectable from . import sqltypes +from .base import NO_ARG from .. import exc from .. import util @@ -463,14 +464,6 @@ class SQLCompiler(Compiled): columns with the table name (i.e. MySQL only) """ - contains_expanding_parameters = False - """True if we've encountered bindparam(..., expanding=True). - - These need to be converted before execution time against the - string statement. - - """ - ansi_bind_rules = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow @@ -507,6 +500,8 @@ class SQLCompiler(Compiled): """ + literal_execute_params = frozenset() + insert_prefetch = update_prefetch = () def __init__( @@ -1267,6 +1262,81 @@ class SQLCompiler(Compiled): % self.dialect.name ) + def _literal_execute_expanding_parameter_literal_binds( + self, parameter, values + ): + if not values: + replacement_expression = self.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) + + elif isinstance(values[0], (tuple, list)): + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.render_literal_value(value, parameter.type) + for value in tuple_element + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) + + return (), replacement_expression + + def _literal_execute_expanding_parameter(self, name, parameter, values): + if parameter.literal_execute: + return self._literal_execute_expanding_parameter_literal_binds( + parameter, values + ) + + if not values: + to_update = [] + replacement_expression = self.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) + + elif isinstance(values[0], (tuple, list)): + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.bindtemplate + % {"name": to_update[i * len(tuple_element) + j][0]} + for j, value in enumerate(tuple_element) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expression = ", ".join( + self.bindtemplate % {"name": key} for key, value in to_update + ) + + return to_update, replacement_expression + def visit_binary( self, binary, override_operator=None, eager_grouping=False, **kw ): @@ -1457,6 +1527,7 @@ class SQLCompiler(Compiled): within_columns_clause=False, literal_binds=False, skip_bind_expression=False, + literal_execute=False, **kwargs ): @@ -1469,18 +1540,28 @@ class SQLCompiler(Compiled): skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, + literal_execute=literal_execute, **kwargs ) - if literal_binds or (within_columns_clause and self.ansi_bind_rules): - if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError( - "Bind parameter '%s' without a " - "renderable value not allowed here." % bindparam.key - ) - return self.render_literal_bindparam( + if not literal_binds: + post_compile = ( + literal_execute + or bindparam.literal_execute + or bindparam.expanding + ) + else: + post_compile = False + + if not literal_execute and ( + literal_binds or (within_columns_clause and self.ansi_bind_rules) + ): + ret = self.render_literal_bindparam( bindparam, within_columns_clause=True, **kwargs ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret name = self._truncate_bindparam(bindparam) @@ -1508,13 +1589,38 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs + if post_compile: + self.literal_execute_params |= {bindparam} + + ret = self.bindparam_string( + name, + post_compile=post_compile, + expanding=bindparam.expanding, + **kwargs ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret + + def render_literal_bindparam( + self, bindparam, render_literal_value=NO_ARG, **kw + ): + if render_literal_value is not NO_ARG: + value = render_literal_value + else: + if bindparam.value is None and bindparam.callable is None: + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." % bindparam.key + ) + value = bindparam.effective_value - def render_literal_bindparam(self, bindparam, **kw): - value = bindparam.effective_value - return self.render_literal_value(value, bindparam.type) + if bindparam.expanding: + leep = self._literal_execute_expanding_parameter_literal_binds + to_update, replacement_expr = leep(bindparam, value) + return replacement_expr + else: + return self.render_literal_value(value, bindparam.type) def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. @@ -1577,16 +1683,20 @@ class SQLCompiler(Compiled): return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw + self, + name, + positional_names=None, + post_compile=False, + expanding=False, + **kw ): if self.positional: if positional_names is not None: positional_names.append(name) else: self.positiontup.append(name) - if expanding: - self.contains_expanding_parameters = True - return "([EXPANDING_%s])" % name + if post_compile: + return "[POSTCOMPILE_%s]" % name else: return self.bindtemplate % {"name": name} |