summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-11-06 17:15:30 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2014-11-06 17:15:30 -0500
commit590498bf844e7dcdcf41d3ac786b4cccbebd2d43 (patch)
tree0455eea3f8555a4b78ec7fa015b06d9ffc88d47f /lib/sqlalchemy/sql
parentb9d430af752b7cc955932a54a8f8db18f46d89a6 (diff)
parent8200c2cd35b3e85a636baabe8324b9ecbbd8fedf (diff)
downloadsqlalchemy-590498bf844e7dcdcf41d3ac786b4cccbebd2d43.tar.gz
Merge branch 'master' into ticket_3100
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py584
-rw-r--r--lib/sqlalchemy/sql/crud.py530
-rw-r--r--lib/sqlalchemy/sql/dml.py26
-rw-r--r--lib/sqlalchemy/sql/elements.py150
-rw-r--r--lib/sqlalchemy/sql/expression.py10
-rw-r--r--lib/sqlalchemy/sql/functions.py31
-rw-r--r--lib/sqlalchemy/sql/schema.py49
-rw-r--r--lib/sqlalchemy/sql/selectable.py138
9 files changed, 961 insertions, 558 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 4d013859c..351e08d0b 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -38,6 +38,7 @@ from .expression import (
false,
False_,
func,
+ funcfilter,
insert,
intersect,
intersect_all,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5149fa4fe..5fa78ad0f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -24,12 +24,10 @@ To generate user-defined SQL strings, see
"""
import re
-from . import schema, sqltypes, operators, functions, \
- util as sql_util, visitors, elements, selectable, base
+from . import schema, sqltypes, operators, functions, visitors, \
+ elements, selectable, crud
from .. import util, exc
-import decimal
import itertools
-import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -64,17 +62,6 @@ BIND_TEMPLATES = {
'named': ":%(name)s"
}
-REQUIRED = util.symbol('REQUIRED', """
-Placeholder for the value within a :class:`.BindParameter`
-which is required to be present when the statement is passed
-to :meth:`.Connection.execute`.
-
-This symbol is typically used when a :func:`.expression.insert`
-or :func:`.expression.update` statement is compiled without parameter
-values present.
-
-""")
-
OPERATORS = {
# binary
@@ -725,7 +712,6 @@ class SQLCompiler(Compiled):
for c in clauselist.clauses)
if s)
-
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
@@ -760,6 +746,12 @@ class SQLCompiler(Compiled):
)
)
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ )
+
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
@@ -819,8 +811,9 @@ class SQLCompiler(Compiled):
text += " GROUP BY " + group_by
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None or cs._offset_clause is not None) and \
- self.limit_clause(cs) or ""
+ text += (cs._limit_clause is not None
+ or cs._offset_clause is not None) and \
+ self.limit_clause(cs, **kwargs) or ""
if self.ctes and \
compound_index == 0 and toplevel:
@@ -876,15 +869,15 @@ class SQLCompiler(Compiled):
isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
- operator = binary.operator
- disp = getattr(self, "visit_%s_binary" % operator.__name__, None)
+ operator_ = binary.operator
+ disp = getattr(self, "visit_%s_binary" % operator_.__name__, None)
if disp:
- return disp(binary, operator, **kw)
+ return disp(binary, operator_, **kw)
else:
try:
- opstring = OPERATORS[operator]
+ opstring = OPERATORS[operator_]
except KeyError:
- raise exc.UnsupportedCompilationError(self, operator)
+ raise exc.UnsupportedCompilationError(self, operator_)
else:
return self._generate_generic_binary(binary, opstring, **kw)
@@ -966,7 +959,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -977,7 +970,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -988,7 +981,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -999,7 +992,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
@@ -1331,6 +1324,9 @@ class SQLCompiler(Compiled):
def get_crud_hint_text(self, table, text):
return None
+ def get_statement_hint_text(self, hint_texts):
+ return " ".join(hint_texts)
+
def _transform_select_for_nested_joins(self, select):
"""Rewrite any "a JOIN (b JOIN c)" expression as
"a JOIN (select * from b JOIN c) AS anon", to support
@@ -1501,29 +1497,7 @@ class SQLCompiler(Compiled):
select, transformed_select)
return text
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
-
- if asfrom:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
- else:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
-
- new_correlate_froms = set(selectable._from_objects(*froms))
- all_correlate_froms = new_correlate_froms.union(correlate_froms)
-
- new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'iswrapper': iswrapper,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
- }
- self.stack.append(new_entry)
+ froms = self._setup_select_stack(select, entry, asfrom, iswrapper)
column_clause_args = kwargs.copy()
column_clause_args.update({
@@ -1534,18 +1508,11 @@ class SQLCompiler(Compiled):
text = "SELECT " # we're off to a good start !
if select._hints:
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
- hint_text = self.get_select_hint_text(byfrom)
+ hint_text, byfrom = self._setup_select_hints(select)
if hint_text:
text += hint_text + " "
+ else:
+ byfrom = None
if select._prefixes:
text += self._generate_prefixes(
@@ -1566,6 +1533,70 @@ class SQLCompiler(Compiled):
if c is not None
]
+ text = self._compose_select_body(
+ text, select, inner_columns, froms, byfrom, kwargs)
+
+ if select._statement_hints:
+ per_dialect = [
+ ht for (dialect_name, ht)
+ in select._statement_hints
+ if dialect_name in ('*', self.dialect.name)
+ ]
+ if per_dialect:
+ text += " " + self.get_statement_hint_text(per_dialect)
+
+ if self.ctes and \
+ compound_index == 0 and toplevel:
+ text = self._render_cte_clause() + text
+
+ self.stack.pop(-1)
+
+ if asfrom and parens:
+ return "(" + text + ")"
+ else:
+ return text
+
+ def _setup_select_hints(self, select):
+ byfrom = dict([
+ (from_, hinttext % {
+ 'name': from_._compiler_dispatch(
+ self, ashint=True)
+ })
+ for (from_, dialect), hinttext in
+ select._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ hint_text = self.get_select_hint_text(byfrom)
+ return hint_text, byfrom
+
+ def _setup_select_stack(self, select, entry, asfrom, iswrapper):
+ correlate_froms = entry['correlate_froms']
+ asfrom_froms = entry['asfrom_froms']
+
+ if asfrom:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms),
+ implicit_correlate_froms=())
+ else:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
+
+ new_correlate_froms = set(selectable._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
+
+ new_entry = {
+ 'asfrom_froms': new_correlate_froms,
+ 'iswrapper': iswrapper,
+ 'correlate_froms': all_correlate_froms,
+ 'selectable': select,
+ }
+ self.stack.append(new_entry)
+ return froms
+
+ def _compose_select_body(
+ self, text, select, inner_columns, froms, byfrom, kwargs):
text += ', '.join(inner_columns)
if froms:
@@ -1609,16 +1640,7 @@ class SQLCompiler(Compiled):
if select._for_update_arg is not None:
text += self.for_update_clause(select, **kwargs)
- if self.ctes and \
- compound_index == 0 and toplevel:
- text = self._render_cte_clause() + text
-
- self.stack.pop(-1)
-
- if asfrom and parens:
- return "(" + text + ")"
- else:
- return text
+ return text
def _generate_prefixes(self, stmt, prefixes, **kw):
clause = " ".join(
@@ -1708,9 +1730,9 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, **kw)
+ crud_params = crud._get_crud_params(self, insert_stmt, **kw)
- if not colparams and \
+ if not crud_params and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The '%s' dialect with current database "
@@ -1725,9 +1747,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." %
self.dialect.name)
- colparams_single = colparams[0]
+ crud_params_single = crud_params[0]
else:
- colparams_single = colparams
+ crud_params_single = crud_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1758,9 +1780,9 @@ class SQLCompiler(Compiled):
text += table_text
- if colparams_single or not supports_default_values:
+ if crud_params_single or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in colparams_single])
+ for c in crud_params_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1771,21 +1793,21 @@ class SQLCompiler(Compiled):
text += " " + returning_clause
if insert_stmt.select is not None:
- text += " %s" % self.process(insert_stmt.select, **kw)
- elif not colparams and supports_default_values:
+ text += " %s" % self.process(self._insert_from_select, **kw)
+ elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (
- ', '.join(c[1] for c in colparam_set)
+ ', '.join(c[1] for c in crud_param_set)
)
- for colparam_set in colparams
+ for crud_param_set in crud_params
)
)
else:
text += " VALUES (%s)" % \
- ', '.join([c[1] for c in colparams])
+ ', '.join([c[1] for c in crud_params])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1842,7 +1864,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, **kw)
+ crud_params = crud._get_crud_params(self, update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1869,7 +1891,7 @@ class SQLCompiler(Compiled):
text += ', '.join(
c[0]._compiler_dispatch(self,
include_table=include_table) +
- '=' + c[1] for c in colparams
+ '=' + c[1] for c in crud_params
)
if self.returning or update_stmt._returning:
@@ -1905,380 +1927,9 @@ class SQLCompiler(Compiled):
return text
- def _create_crud_bind_param(self, col, value, required=False, name=None):
- if name is None:
- name = col.key
- bindparam = elements.BindParameter(name, value,
- type_=col.type, required=required)
- bindparam._is_crud = True
- return bindparam._compiler_dispatch(self)
-
@util.memoized_property
def _key_getters_for_crud_column(self):
- if self.isupdate and self.statement._extra_froms:
- # when extra tables are present, refer to the columns
- # in those extra tables as table-qualified, including in
- # dictionaries and when rendering bind param names.
- # the "main" table of the statement remains unqualified,
- # allowing the most compatibility with a non-multi-table
- # statement.
- _et = set(self.statement._extra_froms)
-
- def _column_as_key(key):
- str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
- return (key.table.name, str_key)
- else:
- return str_key
-
- def _getattr_col_key(col):
- if col.table in _et:
- return (col.table.name, col.key)
- else:
- return col.key
-
- def _col_bind_name(col):
- if col.table in _et:
- return "%s_%s" % (col.table.name, col.key)
- else:
- return col.key
-
- else:
- _column_as_key = elements._column_as_key
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
-
- return _column_as_key, _getattr_col_key, _col_bind_name
-
- def _get_colparams(self, stmt, **kw):
- """create a set of tuples representing column/string pairs for use
- in an INSERT or UPDATE statement.
-
- Also generates the Compiled object's postfetch, prefetch, and
- returning column collections, used for default handling and ultimately
- populating the ResultProxy's prefetch_cols() and postfetch_cols()
- collections.
-
- """
-
- self.postfetch = []
- self.prefetch = []
- self.returning = []
-
- # no parameters in the statement, no parameters in the
- # compiled params - return binds for all columns
- if self.column_keys is None and stmt.parameters is None:
- return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
- for c in stmt.table.columns
- ]
-
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
- else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- self._key_getters_for_crud_column
-
- # if we have statement parameters - set defaults in the
- # compiled params
- if self.column_keys is None:
- parameters = {}
- else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in self.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
-
- # create a list of column assignment clauses as tuples
- values = []
-
- if stmt_parameters is not None:
- for k, v in stmt_parameters.items():
- colkey = _column_as_key(k)
- if colkey is not None:
- parameters.setdefault(colkey, v)
- else:
- # a non-Column expression on the left side;
- # add it to values() in an "as-is" state,
- # coercing right side to bound param
- if elements._is_literal(v):
- v = self.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
- else:
- v = self.process(v.self_group(), **kw)
-
- values.append((k, v))
-
- need_pks = self.isinsert and \
- not self.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
-
- implicit_returning = need_pks and \
- self.dialect.implicit_returning and \
- stmt.table.implicit_returning
-
- if self.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
- elif self.isupdate:
- implicit_return_defaults = (self.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
- else:
- implicit_return_defaults = False
-
- if implicit_return_defaults:
- if stmt._return_defaults is True:
- implicit_return_defaults = set(stmt.table.c)
- else:
- implicit_return_defaults = set(stmt._return_defaults)
-
- postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
-
- check_columns = {}
-
- # special logic that only occurs for multi-table UPDATE
- # statements
- if self.isupdate and stmt._extra_froms and stmt_parameters:
- normalized_params = dict(
- (elements._clause_element_as_expr(c), param)
- for c, param in stmt_parameters.items()
- )
- affected_tables = set()
- for t in stmt._extra_froms:
- for c in t.c:
- if c in normalized_params:
- affected_tables.add(t)
- check_columns[_getattr_col_key(c)] = c
- value = normalized_params[c]
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
- # determine tables which are actually
- # to be updated - process onupdate and
- # server_onupdate for these
- for t in affected_tables:
- for c in t.c:
- if c in normalized_params:
- continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
- )
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(
- c, None, name=_col_bind_name(c)
- )
- )
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- self.postfetch.append(c)
-
- if self.isinsert and stmt.select_names:
- # for an insert from select, we can only use names that
- # are given, so only select for those names.
- cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
- else:
- # iterate through all table columns to maintain
- # ordering, even for those cols that aren't included
- cols = stmt.table.columns
-
- for c in cols:
- col_key = _getattr_col_key(c)
- if col_key in parameters and col_key not in check_columns:
- value = parameters.pop(col_key)
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c)
- if not stmt._has_multi_parameters
- else "%s_0" % _col_bind_name(c)
- )
- else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
- value = value._clone()
- value.type = c.type
-
- if c.primary_key and implicit_returning:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
-
- elif self.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
-
- if implicit_returning:
- if c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- self.returning.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
- self.returning.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- else:
- self.returning.append(c)
- else:
- if (
- (c.default is not None and
- (not c.default.is_sequence or
- self.dialect.supports_sequences)) or
- c is stmt.table._autoincrement_column and
- (self.dialect.supports_sequences or
- self.dialect.
- preexecute_autoincrement_sequences)
- ):
-
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
-
- self.prefetch.append(c)
-
- elif c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
-
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- # don't add primary key column to postfetch
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- elif self.isupdate:
- if c.onupdate is not None and not c.onupdate.is_sequence:
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(), **kw))
- )
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt.parameters
- ).difference(check_columns)
- if check:
- raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
- )
-
- if stmt._has_multi_parameters:
- values_0 = values
- values = [values]
-
- values.extend(
- [
- (
- c,
- (self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- ) if elements._is_literal(row[c.key])
- else self.process(
- row[c.key].self_group(), **kw))
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
- )
-
- return values
+ return crud._key_getters_for_crud_column(self)
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
@@ -2462,17 +2113,18 @@ class DDLCompiler(Compiled):
constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
- return ", \n\t".join(p for p in
- (self.process(constraint)
- for constraint in constraints
- if (
- constraint._create_rule is None or
- constraint._create_rule(self))
- and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
- )
+ return ", \n\t".join(
+ p for p in
+ (self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None or
+ constraint._create_rule(self))
+ and (
+ not self.dialect.supports_alter or
+ not getattr(constraint, 'use_alter', False)
+ )) if p is not None
+ )
def visit_drop_table(self, drop):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
new file mode 100644
index 000000000..831d05be1
--- /dev/null
+++ b/lib/sqlalchemy/sql/crud.py
@@ -0,0 +1,530 @@
+# sql/crud.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Functions used by compiler.py to determine the parameters rendered
+within INSERT and UPDATE statements.
+
+"""
+from .. import util
+from .. import exc
+from . import elements
+import operator
+
+REQUIRED = util.symbol('REQUIRED', """
+Placeholder for the value within a :class:`.BindParameter`
+which is required to be present when the statement is passed
+to :meth:`.Connection.execute`.
+
+This symbol is typically used when a :func:`.expression.insert`
+or :func:`.expression.update` statement is compiled without parameter
+values present.
+
+""")
+
+
+def _get_crud_params(compiler, stmt, **kw):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ Also generates the Compiled object's postfetch, prefetch, and
+ returning column collections, used for default handling and ultimately
+ populating the ResultProxy's prefetch_cols() and postfetch_cols()
+ collections.
+
+ """
+
+ compiler.postfetch = []
+ compiler.prefetch = []
+ compiler.returning = []
+
+ # no parameters in the statement, no parameters in the
+ # compiled params - return binds for all columns
+ if compiler.column_keys is None and stmt.parameters is None:
+ return [
+ (c, _create_bind_param(
+ compiler, c, None, required=True))
+ for c in stmt.table.columns
+ ]
+
+ if stmt._has_multi_parameters:
+ stmt_parameters = stmt.parameters[0]
+ else:
+ stmt_parameters = stmt.parameters
+
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ _column_as_key, _getattr_col_key, _col_bind_name = \
+ _key_getters_for_crud_column(compiler)
+
+ # if we have statement parameters - set defaults in the
+ # compiled params
+ if compiler.column_keys is None:
+ parameters = {}
+ else:
+ parameters = dict((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
+
+ # create a list of column assignment clauses as tuples
+ values = []
+
+ if stmt_parameters is not None:
+ _get_stmt_parameters_params(
+ compiler,
+ parameters, stmt_parameters, _column_as_key, values, kw)
+
+ check_columns = {}
+
+ # special logic that only occurs for multi-table UPDATE
+ # statements
+ if compiler.isupdate and stmt._extra_froms and stmt_parameters:
+ _get_multitable_params(
+ compiler, stmt, stmt_parameters, check_columns,
+ _col_bind_name, _getattr_col_key, values, kw)
+
+ if compiler.isinsert and stmt.select_names:
+ _scan_insert_from_select_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
+ else:
+ _scan_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
+
+ if parameters and stmt_parameters:
+ check = set(parameters).intersection(
+ _column_as_key(k) for k in stmt.parameters
+ ).difference(check_columns)
+ if check:
+ raise exc.CompileError(
+ "Unconsumed column names: %s" %
+ (", ".join("%s" % c for c in check))
+ )
+
+ if stmt._has_multi_parameters:
+ values = _extend_values_for_multiparams(compiler, stmt, values, kw)
+
+ return values
+
+
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None):
+ if name is None:
+ name = col.key
+ bindparam = elements.BindParameter(name, value,
+ type_=col.type, required=required)
+ bindparam._is_crud = True
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler)
+ return bindparam
+
+
+def _key_getters_for_crud_column(compiler):
+ if compiler.isupdate and compiler.statement._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(compiler.statement._extra_froms)
+
+ def _column_as_key(key):
+ str_key = elements._column_as_key(key)
+ if hasattr(key, 'table') and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = elements._column_as_key
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+
+def _scan_insert_from_select_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = [stmt.table.c[_column_as_key(name)]
+ for name in stmt.select_names]
+
+ compiler._insert_from_select = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw)
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ compiler._insert_from_select = compiler._insert_from_select._generate()
+ compiler._insert_from_select._raw_columns += tuple(
+ expr for col, expr in add_select_cols)
+
+
+def _scan_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = stmt.table.columns
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+
+ _append_param_parameter(
+ compiler, stmt, c, col_key, parameters, _col_bind_name,
+ implicit_returning, implicit_return_defaults, values, kw)
+
+ elif compiler.isinsert:
+ if c.primary_key and \
+ need_pks and \
+ (
+ implicit_returning or
+ not postfetch_lastrowid or
+ c is not stmt.table._autoincrement_column
+ ):
+
+ if implicit_returning:
+ _append_param_insert_pk_returning(
+ compiler, stmt, c, values, kw)
+ else:
+ _append_param_insert_pk(compiler, stmt, c, values, kw)
+
+ elif c.default is not None:
+
+ _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults,
+ values, kw)
+
+ elif c.server_default is not None:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ elif compiler.isupdate:
+ _append_param_update(
+ compiler, stmt, c, implicit_return_defaults, values, kw)
+
+
+def _append_param_parameter(
+ compiler, stmt, c, col_key, parameters, _col_bind_name,
+ implicit_returning, implicit_return_defaults, values, kw):
+ value = parameters.pop(col_key)
+ if elements._is_literal(value):
+ value = _create_bind_param(
+ compiler, c, value, required=value is REQUIRED,
+ name=_col_bind_name(c)
+ if not stmt._has_multi_parameters
+ else "%s_0" % _col_bind_name(c)
+ )
+ else:
+ if isinstance(value, elements.BindParameter) and \
+ value.type._isnull:
+ value = value._clone()
+ value.type = c.type
+
+ if c.primary_key and implicit_returning:
+ compiler.returning.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, value))
+
+
+def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
+ if c.default is not None:
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = compiler.process(c.default, **kw)
+ values.append((c, proc))
+ compiler.returning.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.default.arg.self_group(), **kw))
+ )
+ compiler.returning.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+ else:
+ compiler.returning.append(c)
+
+
+def _append_param_insert_pk(compiler, stmt, c, values, kw):
+ if (
+ (c.default is not None and
+ (not c.default.is_sequence or
+ compiler.dialect.supports_sequences)) or
+ c is stmt.table._autoincrement_column and
+ (compiler.dialect.supports_sequences or
+ compiler.dialect.
+ preexecute_autoincrement_sequences)
+ ):
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+
+ compiler.prefetch.append(c)
+
+
+def _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = compiler.process(c.default, **kw)
+ values.append((c, proc))
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif c.default.is_clause_element:
+ proc = compiler.process(c.default.arg.self_group(), **kw)
+ values.append((c, proc))
+
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ # don't add primary key column to postfetch
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+
+
+def _append_param_insert_select_hasdefault(
+ compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = c.default
+ values.append((c, proc))
+ elif c.default.is_clause_element:
+ proc = c.default.arg.self_group()
+ values.append((c, proc))
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None, process=False))
+ )
+ compiler.prefetch.append(c)
+
+
+def _append_param_update(
+ compiler, stmt, c, implicit_return_defaults, values, kw):
+
+ if c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.onupdate.arg.self_group(), **kw))
+ )
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+ elif c.server_onupdate is not None:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+
+def _get_multitable_params(
+ compiler, stmt, stmt_parameters, check_columns,
+ _col_bind_name, _getattr_col_key, values, kw):
+
+ normalized_params = dict(
+ (elements._clause_element_as_expr(c), param)
+ for c, param in stmt_parameters.items()
+ )
+ affected_tables = set()
+ for t in stmt._extra_froms:
+ for c in t.c:
+ if c in normalized_params:
+ affected_tables.add(t)
+ check_columns[_getattr_col_key(c)] = c
+ value = normalized_params[c]
+ if elements._is_literal(value):
+ value = _create_bind_param(
+ compiler, c, value, required=value is REQUIRED,
+ name=_col_bind_name(c))
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, value))
+ # determine tables which are actually to be updated - process onupdate
+ # and server_onupdate for these
+ for t in affected_tables:
+ for c in t.c:
+ if c in normalized_params:
+ continue
+ elif (c.onupdate is not None and not
+ c.onupdate.is_sequence):
+ if c.onupdate.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
+ )
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(
+ compiler, c, None, name=_col_bind_name(c)
+ )
+ )
+ )
+ compiler.prefetch.append(c)
+ elif c.server_onupdate is not None:
+ compiler.postfetch.append(c)
+
+
+def _extend_values_for_multiparams(compiler, stmt, values, kw):
+ values_0 = values
+ values = [values]
+
+ values.extend(
+ [
+ (
+ c,
+ (_create_bind_param(
+ compiler, c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ ) if elements._is_literal(row[c.key])
+ else compiler.process(
+ row[c.key].self_group(), **kw))
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
+ )
+ return values
+
+
+def _get_stmt_parameters_params(
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ for k, v in stmt_parameters.items():
+ colkey = _column_as_key(k)
+ if colkey is not None:
+ parameters.setdefault(colkey, v)
+ else:
+ # a non-Column expression on the left side;
+ # add it to values() in an "as-is" state,
+ # coercing right side to bound param
+ if elements._is_literal(v):
+ v = compiler.process(
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
+ else:
+ v = compiler.process(v.self_group(), **kw)
+
+ values.append((k, v))
+
+
+def _get_returning_modifiers(compiler, stmt):
+ need_pks = compiler.isinsert and \
+ not compiler.inline and \
+ not stmt._returning and \
+ not stmt._has_multi_parameters
+
+ implicit_returning = need_pks and \
+ compiler.dialect.implicit_returning and \
+ stmt.table.implicit_returning
+
+ if compiler.isinsert:
+ implicit_return_defaults = (implicit_returning and
+ stmt._return_defaults)
+ elif compiler.isupdate:
+ implicit_return_defaults = (compiler.dialect.implicit_returning and
+ stmt.table.implicit_returning and
+ stmt._return_defaults)
+ else:
+ implicit_return_defaults = False
+
+ if implicit_return_defaults:
+ if stmt._return_defaults is True:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults)
+
+ postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
+
+ return need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 1934d0776..9f2ce7ce3 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -475,6 +475,7 @@ class Insert(ValuesBase):
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = self.select_names = None
+ self.include_insert_from_select_defaults = False
self.inline = inline
self._returning = returning
self._validate_dialect_kwargs(dialect_kw)
@@ -487,7 +488,7 @@ class Insert(ValuesBase):
return ()
@_generative
- def from_select(self, names, select):
+ def from_select(self, names, select, include_defaults=True):
"""Return a new :class:`.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
@@ -506,6 +507,21 @@ class Insert(ValuesBase):
is not checked before passing along to the database, the database
would normally raise an exception if these column lists don't
correspond.
+ :param include_defaults: if True, non-server default values and
+ SQL expressions as specified on :class:`.Column` objects
+ (as documented in :ref:`metadata_defaults_toplevel`) not
+ otherwise specified in the list of names will be rendered
+ into the INSERT and SELECT statements, so that these values are also
+ included in the data to be inserted.
+
+ .. note:: A Python-side default that uses a Python callable function
+ will only be invoked **once** for the whole statement, and **not
+ per row**.
+
+ .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
+ Python-side and SQL expression column defaults into the
+ SELECT statement for columns otherwise not included in the
+ list of column names.
.. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
implies that the :paramref:`.insert.inline` flag is set to
@@ -514,13 +530,6 @@ class Insert(ValuesBase):
deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
- .. note::
-
- A SELECT..INSERT construct in SQL has no VALUES clause. Therefore
- :class:`.Column` objects which utilize Python-side defaults
- (e.g. as described at :ref:`metadata_defaults_toplevel`)
- will **not** take effect when using :meth:`.Insert.from_select`.
-
.. versionadded:: 0.8.3
"""
@@ -533,6 +542,7 @@ class Insert(ValuesBase):
self.select_names = names
self.inline = True
+ self.include_insert_from_select_defaults = include_defaults
self.select = _interpret_as_select(select)
def _copy_internals(self, clone=_clone, **kw):
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8ec0aa700..fa9b66024 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -228,6 +228,7 @@ class ClauseElement(Visitable):
is_selectable = False
is_clause_element = True
+ description = None
_order_by_label_element = None
_is_from_container = False
@@ -540,7 +541,7 @@ class ClauseElement(Visitable):
__nonzero__ = __bool__
def __repr__(self):
- friendly = getattr(self, 'description', None)
+ friendly = self.description
if friendly is None:
return object.__repr__(self)
else:
@@ -860,6 +861,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
expressions and function calls.
"""
+ while self._is_clone_of is not None:
+ self = self._is_clone_of
+
return _anonymous_label(
'%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))
)
@@ -1616,10 +1620,10 @@ class Null(ColumnElement):
return type_api.NULLTYPE
@classmethod
- def _singleton(cls):
+ def _instance(cls):
"""Return a constant :class:`.Null` construct."""
- return NULL
+ return Null()
def compare(self, other):
return isinstance(other, Null)
@@ -1640,11 +1644,11 @@ class False_(ColumnElement):
return type_api.BOOLEANTYPE
def _negate(self):
- return TRUE
+ return True_()
@classmethod
- def _singleton(cls):
- """Return a constant :class:`.False_` construct.
+ def _instance(cls):
+ """Return a :class:`.False_` construct.
E.g.::
@@ -1678,7 +1682,7 @@ class False_(ColumnElement):
"""
- return FALSE
+ return False_()
def compare(self, other):
return isinstance(other, False_)
@@ -1699,17 +1703,17 @@ class True_(ColumnElement):
return type_api.BOOLEANTYPE
def _negate(self):
- return FALSE
+ return False_()
@classmethod
def _ifnone(cls, other):
if other is None:
- return cls._singleton()
+ return cls._instance()
else:
return other
@classmethod
- def _singleton(cls):
+ def _instance(cls):
"""Return a constant :class:`.True_` construct.
E.g.::
@@ -1744,15 +1748,11 @@ class True_(ColumnElement):
"""
- return TRUE
+ return True_()
def compare(self, other):
return isinstance(other, True_)
-NULL = Null()
-FALSE = False_()
-TRUE = True_()
-
class ClauseList(ClauseElement):
"""Describe a list of clauses, separated by an operator.
@@ -2782,6 +2782,10 @@ class Grouping(ColumnElement):
return self
@property
+ def _key_label(self):
+ return self._label
+
+ @property
def _label(self):
return getattr(self.element, '_label', None) or self.anon_label
@@ -2888,6 +2892,120 @@ class Over(ColumnElement):
))
+class FunctionFilter(ColumnElement):
+ """Represent a function FILTER clause.
+
+ This is a special operator against aggregate and window functions,
+ which controls which rows are passed to it.
+ It's supported only by certain database backends.
+
+ Invocation of :class:`.FunctionFilter` is via
+ :meth:`.FunctionElement.filter`::
+
+ func.count(1).filter(True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+ """
+ __visit_name__ = 'funcfilter'
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce a :class:`.FunctionFilter` object against a function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ """Produce an additional FILTER against the function.
+
+ This method adds additional criteria to the initial criteria
+ set up by :meth:`.FunctionElement.filter`.
+
+ Multiple criteria are joined together at SQL render time
+ via ``AND``.
+
+
+ """
+
+ for criterion in list(criterion):
+ criterion = _expression_literal_as_text(criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ def over(self, partition_by=None, order_by=None):
+ """Produce an OVER clause against this filtered function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.rank().filter(MyClass.y > 5).over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over, funcfilter
+ over(funcfilter(func.rank(), MyClass.y > 5), order_by='x')
+
+ See :func:`~.expression.over` for a full description.
+
+ """
+ return Over(self, partition_by=partition_by, order_by=order_by)
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ def get_children(self, **kwargs):
+ return [c for c in
+ (self.func, self.criterion)
+ if c is not None]
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.func = clone(self.func, **kw)
+ if self.criterion is not None:
+ self.criterion = clone(self.criterion, **kw)
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(
+ *[c._from_objects for c in (self.func, self.criterion)
+ if c is not None]
+ ))
+
+
class Label(ColumnElement):
"""Represents a column label (AS).
@@ -3491,7 +3609,7 @@ def _string_or_unprintable(element):
else:
try:
return str(element)
- except:
+ except Exception:
return "unprintable element %r" % element
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d96f048b9..2ffc5468c 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\
True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
Grouping, not_, \
collate, literal_column, between,\
- literal, outparam, type_coerce, ClauseList
+ literal, outparam, type_coerce, ClauseList, FunctionFilter
from .elements import SavepointClause, RollbackToSavepointClause, \
ReleaseSavepointClause
@@ -89,14 +89,16 @@ asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
UnaryExpression._create_distinct, ".expression.distinct")
-true = public_factory(True_._singleton, ".expression.true")
-false = public_factory(False_._singleton, ".expression.false")
-null = public_factory(Null._singleton, ".expression.null")
+true = public_factory(True_._instance, ".expression.true")
+false = public_factory(False_._instance, ".expression.false")
+null = public_factory(Null._instance, ".expression.null")
join = public_factory(Join._create_join, ".expression.join")
outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
+funcfilter = public_factory(
+ FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7efb1e916..9280c7d60 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,7 @@ from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter
+ Over, BindParameter, FunctionFilter
from .selectable import FromClause, Select, Alias
from . import operators
@@ -116,6 +116,35 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
return Over(self, partition_by=partition_by, order_by=order_by)
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :class:`.FunctionFilter`
+
+ :func:`.funcfilter`
+
+
+ """
+ if not criterion:
+ return self
+ return FunctionFilter(self, *criterion)
+
@property
def _from_objects(self):
return self.clauses._from_objects
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index d9fd37f92..96cabbf4f 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -412,8 +412,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
table.dispatch.after_parent_attach(table, metadata)
return table
except:
- metadata._remove_table(name, schema)
- raise
+ with util.safe_reraise():
+ metadata._remove_table(name, schema)
@property
@util.deprecated('0.9', 'Use ``table.schema.quote``')
@@ -728,7 +728,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
checkfirst=checkfirst)
def tometadata(self, metadata, schema=RETAIN_SCHEMA,
- referred_schema_fn=None):
+ referred_schema_fn=None, name=None):
"""Return a copy of this :class:`.Table` associated with a different
:class:`.MetaData`.
@@ -785,13 +785,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
.. versionadded:: 0.9.2
- """
+ :param name: optional string name indicating the target table name.
+ If not specified or None, the table name is retained. This allows
+ a :class:`.Table` to be copied to the same :class:`.MetaData` target
+ with a new name.
+
+ .. versionadded:: 1.0.0
+ """
+ if name is None:
+ name = self.name
if schema is RETAIN_SCHEMA:
schema = self.schema
elif schema is None:
schema = metadata.schema
- key = _get_table_key(self.name, schema)
+ key = _get_table_key(name, schema)
if key in metadata.tables:
util.warn("Table '%s' already exists within the given "
"MetaData - not copying." % self.description)
@@ -801,7 +809,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
for c in self.columns:
args.append(c.copy(schema=schema))
table = Table(
- self.name, metadata, schema=schema,
+ name, metadata, schema=schema,
*args, **self.kwargs
)
for c in self.constraints:
@@ -1061,8 +1069,8 @@ class Column(SchemaItem, ColumnClause):
conditionally rendered differently on different backends,
consider custom compilation rules for :class:`.CreateColumn`.
- ..versionadded:: 0.8.3 Added the ``system=True`` parameter to
- :class:`.Column`.
+ .. versionadded:: 0.8.3 Added the ``system=True`` parameter to
+ :class:`.Column`.
"""
@@ -1222,8 +1230,10 @@ class Column(SchemaItem, ColumnClause):
existing = getattr(self, 'table', None)
if existing is not None and existing is not table:
raise exc.ArgumentError(
- "Column object already assigned to Table '%s'" %
- existing.description)
+ "Column object '%s' already assigned to Table '%s'" % (
+ self.key,
+ existing.description
+ ))
if self.key in table._columns:
col = table._columns.get(self.key)
@@ -1547,7 +1557,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
)
return self._schema_item_copy(fk)
- def _get_colspec(self, schema=None):
+ def _get_colspec(self, schema=None, table_name=None):
"""Return a string based 'column specification' for this
:class:`.ForeignKey`.
@@ -1557,7 +1567,15 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"""
if schema:
_schema, tname, colname = self._column_tokens
+ if table_name is not None:
+ tname = table_name
return "%s.%s.%s" % (schema, tname, colname)
+ elif table_name:
+ schema, tname, colname = self._column_tokens
+ if schema:
+ return "%s.%s.%s" % (schema, table_name, colname)
+ else:
+ return "%s.%s" % (table_name, colname)
elif self._table_column is not None:
return "%s.%s" % (
self._table_column.table.fullname, self._table_column.key)
@@ -2649,10 +2667,15 @@ class ForeignKeyConstraint(Constraint):
event.listen(table.metadata, "before_drop",
ddl.DropConstraint(self, on=supports_alter))
- def copy(self, schema=None, **kw):
+ def copy(self, schema=None, target_table=None, **kw):
fkc = ForeignKeyConstraint(
[x.parent.key for x in self._elements.values()],
- [x._get_colspec(schema=schema)
+ [x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None)
for x in self._elements.values()],
name=self.name,
onupdate=self.onupdate,
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9e8cb3bc5..8198a6733 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -746,6 +746,33 @@ class Join(FromClause):
providing a "natural join".
"""
+ constraints = cls._joincond_scan_left_right(
+ a, a_subset, b, consider_as_foreign_keys)
+
+ if len(constraints) > 1:
+ cls._joincond_trim_constraints(
+ a, b, constraints, consider_as_foreign_keys)
+
+ if len(constraints) == 0:
+ if isinstance(b, FromGrouping):
+ hint = " Perhaps you meant to convert the right side to a "\
+ "subquery using alias()?"
+ else:
+ hint = ""
+ raise exc.NoForeignKeysError(
+ "Can't find any foreign key relationships "
+ "between '%s' and '%s'.%s" %
+ (a.description, b.description, hint))
+
+ crit = [(x == y) for x, y in list(constraints.values())[0]]
+ if len(crit) == 1:
+ return (crit[0])
+ else:
+ return and_(*crit)
+
+ @classmethod
+ def _joincond_scan_left_right(
+ cls, a, a_subset, b, consider_as_foreign_keys):
constraints = collections.defaultdict(list)
for left in (a_subset, a):
@@ -780,57 +807,41 @@ class Join(FromClause):
if nrte.table_name == b.name:
raise
else:
- # this is totally covered. can't get
- # coverage to mark it.
continue
if col is not None:
constraints[fk.constraint].append((col, fk.parent))
if constraints:
break
+ return constraints
+ @classmethod
+ def _joincond_trim_constraints(
+ cls, a, b, constraints, consider_as_foreign_keys):
+ # more than one constraint matched. narrow down the list
+ # to include just those FKCs that match exactly to
+ # "consider_as_foreign_keys".
+ if consider_as_foreign_keys:
+ for const in list(constraints):
+ if set(f.parent for f in const.elements) != set(
+ consider_as_foreign_keys):
+ del constraints[const]
+
+ # if still multiple constraints, but
+ # they all refer to the exact same end result, use it.
if len(constraints) > 1:
- # more than one constraint matched. narrow down the list
- # to include just those FKCs that match exactly to
- # "consider_as_foreign_keys".
- if consider_as_foreign_keys:
- for const in list(constraints):
- if set(f.parent for f in const.elements) != set(
- consider_as_foreign_keys):
- del constraints[const]
-
- # if still multiple constraints, but
- # they all refer to the exact same end result, use it.
- if len(constraints) > 1:
- dedupe = set(tuple(crit) for crit in constraints.values())
- if len(dedupe) == 1:
- key = list(constraints)[0]
- constraints = {key: constraints[key]}
-
- if len(constraints) != 1:
- raise exc.AmbiguousForeignKeysError(
- "Can't determine join between '%s' and '%s'; "
- "tables have more than one foreign key "
- "constraint relationship between them. "
- "Please specify the 'onclause' of this "
- "join explicitly." % (a.description, b.description))
-
- if len(constraints) == 0:
- if isinstance(b, FromGrouping):
- hint = " Perhaps you meant to convert the right side to a "\
- "subquery using alias()?"
- else:
- hint = ""
- raise exc.NoForeignKeysError(
- "Can't find any foreign key relationships "
- "between '%s' and '%s'.%s" %
- (a.description, b.description, hint))
-
- crit = [(x == y) for x, y in list(constraints.values())[0]]
- if len(crit) == 1:
- return (crit[0])
- else:
- return and_(*crit)
+ dedupe = set(tuple(crit) for crit in constraints.values())
+ if len(dedupe) == 1:
+ key = list(constraints)[0]
+ constraints = {key: constraints[key]}
+
+ if len(constraints) != 1:
+ raise exc.AmbiguousForeignKeysError(
+ "Can't determine join between '%s' and '%s'; "
+ "tables have more than one foreign key "
+ "constraint relationship between them. "
+ "Please specify the 'onclause' of this "
+ "join explicitly." % (a.description, b.description))
def select(self, whereclause=None, **kwargs):
"""Create a :class:`.Select` from this :class:`.Join`.
@@ -2153,6 +2164,7 @@ class Select(HasPrefixes, GenerativeSelect):
_prefixes = ()
_hints = util.immutabledict()
+ _statement_hints = ()
_distinct = False
_from_cloned = None
_correlate = ()
@@ -2525,10 +2537,30 @@ class Select(HasPrefixes, GenerativeSelect):
return self._get_display_froms()
+ def with_statement_hint(self, text, dialect_name='*'):
+ """add a statement hint to this :class:`.Select`.
+
+ This method is similar to :meth:`.Select.with_hint` except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.Select.with_hint`
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
@_generative
def with_hint(self, selectable, text, dialect_name='*'):
- """Add an indexing hint for the given selectable to this
- :class:`.Select`.
+ """Add an indexing or other executional context hint for the given
+ selectable to this :class:`.Select`.
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
@@ -2540,7 +2572,7 @@ class Select(HasPrefixes, GenerativeSelect):
following::
select([mytable]).\\
- with_hint(mytable, "+ index(%(name)s ix_mytable)")
+ with_hint(mytable, "index(%(name)s ix_mytable)")
Would render SQL as::
@@ -2551,13 +2583,19 @@ class Select(HasPrefixes, GenerativeSelect):
and Sybase simultaneously::
select([mytable]).\\
- with_hint(
- mytable, "+ index(%(name)s ix_mytable)", 'oracle').\\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\\
with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+ .. seealso::
+
+ :meth:`.Select.with_statement_hint`
+
"""
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text), )
+ else:
+ self._hints = self._hints.union(
+ {(selectable, dialect_name): text})
@property
def type(self):