summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py200
1 files changed, 87 insertions, 113 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d937d0507..52291c487 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -29,7 +29,8 @@ import itertools, re
from operator import attrgetter
from sqlalchemy import util, exc
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import Visitable, cloned_traverse
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
@@ -876,7 +877,7 @@ def _compound_select(keyword, *selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
def _from_objects(*elements, **kwargs):
return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
@@ -890,7 +891,7 @@ def _labeled(element):
def _literal_as_text(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return _TextClause(unicode(element))
else:
return element
@@ -898,7 +899,7 @@ def _literal_as_text(element):
def _literal_as_column(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return literal_column(str(element))
else:
return element
@@ -906,7 +907,7 @@ def _literal_as_column(element):
def _literal_as_binds(element, name=None, type_=None):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
if element is None:
return null()
else:
@@ -917,15 +918,18 @@ def _literal_as_binds(element, name=None, type_=None):
def _no_literals(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ elif not isinstance(element, Visitable):
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function "
+ "to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+ "failed to locate a corresponding column from table '%s'"
+ % (column, getattr(column, 'table', None), fromclause.description))
return c
def _selectable(element):
@@ -934,39 +938,15 @@ def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element)
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
-class _FigureVisitName(type):
- def __init__(cls, clsname, bases, dict):
- if not '__visit_name__' in cls.__dict__:
- m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
- x = m.group(1)
- x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
- cls.__visit_name__ = x.lower()
-
- # set up an optimized visit dispatch function
- # for use by the compiler
- visit_name = cls.__dict__["__visit_name__"]
- if isinstance(visit_name, str):
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return visitor.visit_%s(self, **kw)" % visit_name
- else:
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
- env = locals().copy()
- exec func_text in env
- cls._compiler_dispatch = env['_compiler_dispatch']
-
- super(_FigureVisitName, cls).__init__(clsname, bases, dict)
-
-class ClauseElement(object):
+class ClauseElement(Visitable):
"""Base class for elements of a programmatically constructed SQL expression."""
- __metaclass__ = _FigureVisitName
_annotations = {}
supports_execution = False
@@ -976,6 +956,7 @@ class ClauseElement(object):
This method may be used by a generative API. Its also used as
part of the "deep" copy afforded by a traversal that combines
the _copy_internals() method.
+
"""
c = self.__class__.__new__(self.__class__)
c.__dict__ = self.__dict__.copy()
@@ -1001,8 +982,8 @@ class ClauseElement(object):
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
``Select`` statement.
+
"""
-
raise NotImplementedError(repr(self))
def _annotate(self, values):
@@ -1049,7 +1030,7 @@ class ClauseElement(object):
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
+ return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
@@ -2637,14 +2618,13 @@ class _ColumnClause(_Immutable, ColumnElement):
rules applied regardless of case sensitive settings. the
``literal_column()`` function is usually used to create such a
``_ColumnClause``.
+
"""
-
def __init__(self, text, selectable=None, type_=None, is_literal=False):
ColumnElement.__init__(self)
self.key = self.name = text
self.table = selectable
self.type = sqltypes.to_instance(type_)
- self.__label = None
self.is_literal = is_literal
@util.memoized_property
@@ -2655,23 +2635,25 @@ class _ColumnClause(_Immutable, ColumnElement):
def _label(self):
if self.is_literal:
return None
- if not self.__label:
- if self.table and self.table.named_with_column:
- if getattr(self.table, 'schema', None):
- self.__label = self.table.schema + "_" + self.table.name + "_" + self.name
- else:
- self.__label = self.table.name + "_" + self.name
-
- if self.__label in self.table.c:
- label = self.__label
- counter = 1
- while label in self.table.c:
- label = self.__label + "_" + str(counter)
- counter += 1
- self.__label = label
+
+ elif self.table and self.table.named_with_column:
+ if getattr(self.table, 'schema', None):
+ label = self.table.schema + "_" + self.table.name + "_" + self.name
else:
- self.__label = self.name
- return self.__label
+ label = self.table.name + "_" + self.name
+
+ if label in self.table.c:
+ # TODO: coverage does not seem to be present for this
+ _label = label
+ counter = 1
+ while _label in self.table.c:
+ _label = label + "_" + str(counter)
+ counter += 1
+ label = _label
+ return label
+
+ else:
+ return self.name
def label(self, name):
if name is None:
@@ -2723,7 +2705,7 @@ class TableClause(_Immutable, FromClause):
def _export_columns(self):
raise NotImplementedError()
- @property
+ @util.memoized_property
def description(self):
return self.name.encode('ascii', 'backslashreplace')
@@ -2756,6 +2738,14 @@ class TableClause(_Immutable, FromClause):
def _get_from_objects(self, **modifiers):
return [self]
+@util.decorator
+def _generative(fn, *args, **kw):
+ """Mark a method as generative."""
+
+ self = args[0]._generate()
+ fn(self, *args[1:], **kw)
+ return self
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
@@ -2784,6 +2774,7 @@ class _SelectBaseMixin(object):
"""
return _ScalarSelect(self)
+ @_generative
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
@@ -2793,9 +2784,7 @@ class _SelectBaseMixin(object):
among the individual FROM clauses.
"""
- s = self._generate()
- s.use_labels = True
- return s
+ self.use_labels = True
def label(self, name):
"""return a 'scalar' representation of this selectable, embedded as a subquery
@@ -2806,12 +2795,11 @@ class _SelectBaseMixin(object):
"""
return self.as_scalar().label(name)
+ @_generative
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
- s = self._generate()
- s._autocommit = True
- return s
+ self._autocommit = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
@@ -2819,39 +2807,35 @@ class _SelectBaseMixin(object):
s._reset_exported()
return s
+ @_generative
def limit(self, limit):
"""return a new selectable with the given LIMIT criterion applied."""
- s = self._generate()
- s._limit = limit
- return s
+ self._limit = limit
+ @_generative
def offset(self, offset):
"""return a new selectable with the given OFFSET criterion applied."""
- s = self._generate()
- s._offset = offset
- return s
+ self._offset = offset
+ @_generative
def order_by(self, *clauses):
"""return a new selectable with the given list of ORDER BY criterion applied.
The criterion will be appended to any pre-existing ORDER BY criterion.
"""
- s = self._generate()
- s.append_order_by(*clauses)
- return s
+ self.append_order_by(*clauses)
+ @_generative
def group_by(self, *clauses):
"""return a new selectable with the given list of GROUP BY criterion applied.
The criterion will be appended to any pre-existing GROUP BY criterion.
"""
- s = self._generate()
- s.append_group_by(*clauses)
- return s
+ self.append_group_by(*clauses)
def append_order_by(self, *clauses):
"""Append the given ORDER BY criterion applied to this selectable.
@@ -3112,72 +3096,67 @@ class Select(_SelectBaseMixin, FromClause):
self._raw_columns + list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+ @_generative
def column(self, column):
"""return a new select() construct with the given column expression added to its columns clause."""
- s = self._generate()
column = _literal_as_column(column)
if isinstance(column, _ScalarSelect):
column = column.self_group(against=operators.comma_op)
- s._raw_columns = s._raw_columns + [column]
- s._froms = s._froms.union(_from_objects(column))
- return s
+ self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
+ @_generative
def with_only_columns(self, columns):
"""return a new select() construct with its columns clause replaced with the given columns."""
- s = self._generate()
- s._raw_columns = [
+
+ self._raw_columns = [
isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
for c in
[_literal_as_column(c) for c in columns]
]
- return s
+ @_generative
def where(self, whereclause):
"""return a new select() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_whereclause(whereclause)
- return s
+ self.append_whereclause(whereclause)
+ @_generative
def having(self, having):
"""return a new select() construct with the given expression added to its HAVING clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_having(having)
- return s
+ self.append_having(having)
+ @_generative
def distinct(self):
"""return a new select() construct which will apply DISTINCT to its columns clause."""
- s = self._generate()
- s._distinct = True
- return s
+ self._distinct = True
+ @_generative
def prefix_with(self, clause):
"""return a new select() construct which will apply the given expression to the start of its
columns clause, not using any commas."""
- s = self._generate()
clause = _literal_as_text(clause)
- s._prefixes = s._prefixes + [clause]
- return s
+ self._prefixes = self._prefixes + [clause]
+ @_generative
def select_from(self, fromclause):
"""return a new select() construct with the given FROM expression applied to its list of
FROM objects."""
- s = self._generate()
if _is_literal(fromclause):
fromclause = _TextClause(fromclause)
- s._froms = s._froms.union([fromclause])
- return s
+ self._froms = self._froms.union([fromclause])
+ @_generative
def correlate(self, *fromclauses):
"""return a new select() construct which will correlate the given FROM clauses to that
of an enclosing select(), if a match is found.
@@ -3192,13 +3171,11 @@ class Select(_SelectBaseMixin, FromClause):
If the fromclause is None, correlation is disabled for the returned select().
"""
- s = self._generate()
- s._should_correlate = False
+ self._should_correlate = False
if fromclauses == (None,):
- s._correlate = set()
+ self._correlate = set()
else:
- s._correlate = s._correlate.union(fromclauses)
- return s
+ self._correlate = self._correlate.union(fromclauses)
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
@@ -3416,16 +3393,15 @@ class Insert(_ValuesBase):
def _copy_internals(self, clone=_clone):
self.parameters = self.parameters.copy()
+ @_generative
def prefix_with(self, clause):
"""Add a word or expression between INSERT and INTO. Generative.
If multiple prefixes are supplied, they will be separated with
spaces.
"""
- gen = self._generate()
clause = _literal_as_text(clause)
- gen._prefixes = self._prefixes + [clause]
- return gen
+ self._prefixes = self._prefixes + [clause]
class Update(_ValuesBase):
def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
@@ -3450,16 +3426,15 @@ class Update(_ValuesBase):
self._whereclause = clone(self._whereclause)
self.parameters = self.parameters.copy()
+ @_generative
def where(self, whereclause):
"""return a new update() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
class Delete(_UpdateBase):
@@ -3479,16 +3454,15 @@ class Delete(_UpdateBase):
else:
return ()
+ @_generative
def where(self, whereclause):
"""return a new delete() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
def _copy_internals(self, clone=_clone):
self._whereclause = clone(self._whereclause)