diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 194 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 4 |
2 files changed, 97 insertions, 101 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 55001dc70..a448fa6d3 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -820,6 +820,12 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column(column, require_embedded=require_embedded) + if not c: + raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + return c + def _selectable(element): if hasattr(element, '__selectable__'): return element.__selectable__() @@ -958,13 +964,8 @@ class ClauseElement(object): return False - def _find_engine(self): - """Default strategy for locating an engine within the clause element. - - Relies upon a local engine property, or looks in the *from* - objects which ultimately have to contain Tables or - TableClauses. - """ + def bind(self): + """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" try: if self._bind is not None: @@ -979,8 +980,7 @@ class ClauseElement(object): return engine else: return None - - bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") + bind = property(bind) def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" @@ -1406,7 +1406,6 @@ class ColumnElement(ClauseElement, _CompareMixin): return self._base_columns self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns - base_columns = property(base_columns) def proxy_set(self): @@ -1603,7 +1602,7 @@ class FromClause(Selectable): from sqlalchemy.sql import util return util.ClauseAdapter(alias).traverse(self, clone=True) - def corresponding_column(self, column, raiseerr=True, require_embedded=False): + def corresponding_column(self, column, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` object from this ``Selectable`` which corresponds to that original ``Column`` via a common anscestor column. @@ -1611,10 +1610,6 @@ class FromClause(Selectable): column the target ``ColumnElement`` to be matched - raiseerr - if True, raise an error if the given ``ColumnElement`` could - not be matched. if False, non-matches will return None. - require_embedded only return corresponding columns for the given ``ColumnElement``, if the given ``ColumnElement`` is @@ -1624,12 +1619,6 @@ class FromClause(Selectable): of this ``FromClause``. """ - if require_embedded and column not in self._get_all_embedded_columns(): - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table.description)) - # dont dig around if the column is locally present if self.c.contains_column(column): return column @@ -1638,16 +1627,12 @@ class FromClause(Selectable): target_set = column.proxy_set for c in self.c + [self.oid_column]: i = c.proxy_set.intersection(target_set) - if i and (intersect is None or len(i) > len(intersect)): + if i and \ + (not require_embedded or c.proxy_set.issuperset(target_set)) and \ + (intersect is None or len(i) > len(intersect)): col, intersect = c, i - if col: - return col - - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.description)) - + return col + def description(self): """a brief description of this FromClause. @@ -1666,17 +1651,6 @@ class FromClause(Selectable): if hasattr(self, attr): delattr(self, attr) - def _get_all_embedded_columns(self): - if hasattr(self, '_embedded_columns'): - return self._embedded_columns - ret = util.Set() - class FindCols(visitors.ClauseVisitor): - def visit_column(self, col): - ret.add(col) - FindCols().traverse(self) - self._embedded_columns = ret - return ret - def _expr_attr_func(name): def attr(self): try: @@ -1684,12 +1658,11 @@ class FromClause(Selectable): except AttributeError: self._export_columns() return getattr(self, name) - return attr + return property(attr) - columns = property(_expr_attr_func('_columns')) - c = property(_expr_attr_func('_columns')) - primary_key = property(_expr_attr_func('_primary_key')) - foreign_keys = property(_expr_attr_func('_foreign_keys')) + columns = c = _expr_attr_func('_columns') + primary_key = _expr_attr_func('_primary_key') + foreign_keys = _expr_attr_func('_foreign_keys') def _export_columns(self, columns=None): """Initialize column collections.""" @@ -1881,14 +1854,14 @@ class _TextClause(ClauseElement): for b in bindparams: self.bindparams[b.key] = b - def _get_type(self): + def type(self): if self.typemap is not None and len(self.typemap) == 1: return list(self.typemap)[0] else: return None - type = property(_get_type) + type = property(type) - columns = property(lambda s:[]) + columns = [] def _copy_internals(self, clone=_clone): self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()]) @@ -2329,7 +2302,12 @@ class Join(FromClause): else: return and_(*crit) - def _get_folded_equivalents(self, equivs=None): + def _folded_equivalents(self, equivs=None): + """Returns the column list of this Join with all equivalently-named, + equated columns folded into one column, where 'equated' means they are + equated to each other in the ON clause of this join. + """ + if self.__folded_equivalents is not None: return self.__folded_equivalents if equivs is None: @@ -2342,11 +2320,11 @@ class Join(FromClause): LocateEquivs().traverse(self.onclause) collist = [] if isinstance(self.left, Join): - left = self.left._get_folded_equivalents(equivs) + left = self.left._folded_equivalents(equivs) else: left = list(self.left.columns) if isinstance(self.right, Join): - right = self.right._get_folded_equivalents(equivs) + right = self.right._folded_equivalents(equivs) else: right = list(self.right.columns) used = util.Set() @@ -2359,10 +2337,7 @@ class Join(FromClause): collist.append(c) self.__folded_equivalents = collist return self.__folded_equivalents - - folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, " - "equated columns folded into one column, where 'equated' means they are " - "equated to each other in the ON clause of this join.") + folded_equivalents = property(_folded_equivalents) def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -2391,7 +2366,9 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) - bind = property(lambda s:s.left.bind or s.right.bind) + def bind(self): + return self.left.bind or self.right.bind + bind = property(bind) def alias(self, name=None): """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. @@ -2474,8 +2451,10 @@ class Alias(FromClause): def _get_from_objects(self, **modifiers): return [self] - - bind = property(lambda s: s.selectable.bind) + + def bind(self): + return self.selectable.bind + bind = property(bind) class _ColumnElementAdapter(ColumnElement): """Adapts a ClauseElement which may or may not be a @@ -2486,9 +2465,14 @@ class _ColumnElementAdapter(ColumnElement): def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) - - key = property(lambda s: s.elem.key) - _label = property(lambda s: s.elem._label) + + def key(self): + return self.elem.key + key = property(key) + + def _label(self): + return self.elem._label + _label = property(_label) def _copy_internals(self, clone=_clone): self.elem = clone(self.elem) @@ -2520,8 +2504,13 @@ class _FromGrouping(FromClause): def __init__(self, elem): self.elem = elem - columns = c = property(lambda s:s.elem.columns) - _hide_froms = property(lambda s:s.elem._hide_froms) + def columns(self): + return self.elem.columns + columns = c = property(columns) + + def _hide_froms(self): + return self.elem._hide_froms + _hide_froms = property(_hide_froms) def get_children(self, **kwargs): return self.elem, @@ -2553,23 +2542,34 @@ class _Label(ColumnElement): self.obj = obj.self_group(against=operators.as_) self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) - key = property(lambda s: s.name) - _label = property(lambda s: s.name) - proxies = property(lambda s:s.obj.proxies) - base_columns = property(lambda s:s.obj.base_columns) - proxy_set = property(lambda s:s.obj.proxy_set) - primary_key = property(lambda s:s.obj.primary_key) - foreign_keys = property(lambda s:s.obj.foreign_keys) + def key(self): + return self.name + key = property(key) + + def _label(self): + return self.name + _label = property(_label) + + def _proxy_attr(name): + def attr(self): + return getattr(self.obj, name) + return property(attr) + + proxies = _proxy_attr('proxies') + base_columns = _proxy_attr('base_columns') + proxy_set = _proxy_attr('proxy_set') + primary_key = _proxy_attr('primary_key') + foreign_keys = _proxy_attr('foreign_keys') def expression_element(self): return self.obj - def _copy_internals(self, clone=_clone): - self.obj = clone(self.obj) - def get_children(self, **kwargs): return self.obj, + def _copy_internals(self, clone=_clone): + self.obj = clone(self.obj) + def _get_from_objects(self, **modifiers): return self.obj._get_from_objects(**modifiers) @@ -2623,13 +2623,8 @@ class _ColumnClause(ColumnElement): # ColumnClause is immutable return self - def _get_label(self): - """Generate a 'label' for this column. - - The label is a product of the parent table name and column - name, and is treated as a unique identifier of this ``Column`` - across all ``Tables`` and derived selectables for a particular - metadata collection. + def _label(self): + """Generate a 'label' string for this column. """ # for a "literal" column, we've no idea what the text is @@ -2647,7 +2642,7 @@ class _ColumnClause(ColumnElement): self.__label = self.name return self.__label - _label = property(_get_label) + _label = property(_label) def label(self, name): # if going off the "__label" property and its None, we have @@ -2903,10 +2898,9 @@ class _ScalarSelect(_Grouping): raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") self.type = cols[0].type - def _no_cols(self): + def columns(self): raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") - c = property(_no_cols) - columns = c + columns = c = property(columns) def self_group(self, **kwargs): return self @@ -2979,14 +2973,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause): for t in s._table_iterator(): yield t - def _find_engine(self): + def bind(self): for s in self.selects: - e = s._find_engine() + e = s.bind if e: return e else: return None - + bind = property(bind) + class Select(_SelectBaseMixin, FromClause): """Represents a ``SELECT`` statement. @@ -3115,15 +3110,18 @@ class Select(_SelectBaseMixin, FromClause): self._all_froms = froms return froms - def _get_inner_columns(self): + def inner_columns(self): + """a collection of all ColumnElement expressions which would + be rendered into the columns clause of the resulting SELECT statement. + """ + for c in self._raw_columns: if isinstance(c, Selectable): for co in c.columns: yield co else: yield c - - inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""") + inner_columns = property(inner_columns) def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): @@ -3412,11 +3410,7 @@ class Select(_SelectBaseMixin, FromClause): if isinstance(t, TableClause): yield t - def _find_engine(self): - """Try to return a Engine, either explicitly set in this - object, or searched within the from clauses for one. - """ - + def bind(self): if self._bind is not None: return self._bind for f in self._froms: @@ -3436,7 +3430,8 @@ class Select(_SelectBaseMixin, FromClause): self._bind = e return e return None - + bind = property(bind) + class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -3459,9 +3454,10 @@ class _UpdateBase(ClauseElement): else: return parameters - def _find_engine(self): + def bind(self): return self.table.bind - + bind = property(bind) + class Insert(_UpdateBase): def __init__(self, table, values=None, inline=False, **kwargs): self.table = table diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d6b10a78a..b45c0425c 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -198,10 +198,10 @@ class ClauseAdapter(AbstractClauseProcessor): if self.exclude is not None: if col in self.exclude: return None - newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True) + newcol = self.selectable.corresponding_column(col, require_embedded=True) if newcol is None and self.equivalents is not None and col in self.equivalents: for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True) + newcol = self.selectable.corresponding_column(equiv, require_embedded=True) if newcol: return newcol return newcol |
