diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-30 18:15:02 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-30 18:15:02 -0400 |
| commit | 00738b252c280111dafc8a034eade1507c1dddd8 (patch) | |
| tree | 84250759b0e653e7b72278b649ccc00ce3d074a7 /lib/sqlalchemy/sql | |
| parent | 62d6bf4cc33171ac21cd9b4d52701d6af39cfb42 (diff) | |
| parent | 4cbe117eb2feb7cff28c66d849d3a0613448fdce (diff) | |
| download | sqlalchemy-00738b252c280111dafc8a034eade1507c1dddd8.tar.gz | |
merge trunk. Re-instating topological._find_cycles for the moment
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 21 |
4 files changed, 108 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4e9175ae8..78c65771b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -305,11 +305,13 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + self.process(grouping.element, **kwargs) + ")" - def visit_label(self, label, result_map=None, within_columns_clause=False, **kw): + def visit_label(self, label, result_map=None, + within_label_clause=False, + within_columns_clause=False, **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - if within_columns_clause: + if within_columns_clause and not within_label_clause: labelname = isinstance(label.name, sql._generated_label) and \ self._truncated_identifier("colident", label.name) or label.name @@ -318,13 +320,14 @@ class SQLCompiler(engine.Compiled): (label.name, (label, label.element, labelname), label.element.type) return self.process(label.element, - within_columns_clause=within_columns_clause, + within_columns_clause=True, + within_label_clause=True, **kw) + \ OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: return self.process(label.element, - within_columns_clause=within_columns_clause, + within_columns_clause=False, **kw) def visit_column(self, column, result_map=None, **kwargs): @@ -625,13 +628,22 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} - def visit_alias(self, alias, asfrom=False, **kwargs): - if asfrom: + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: alias_name = isinstance(alias.name, sql._generated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name - - return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ self.preparer.format_alias(alias, alias_name) + + if fromhints and alias in fromhints: + hinttext = self.get_from_hint_text(alias, fromhints[alias]) + if hinttext: + ret += " " + hinttext + + return ret else: return self.process(alias.original, **kwargs) @@ -658,8 +670,15 @@ class SQLCompiler(engine.Compiled): else: return column + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, compound_index=1, **kwargs): + iswrapper=False, fromhints=None, + compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -694,6 +713,18 @@ class SQLCompiler(engine.Compiled): ] text = "SELECT " # we're off to a good start ! + + if select._hints: + byfrom = dict([ + (from_, hinttext % {'name':self.process(from_, ashint=True)}) + for (from_, dialect), hinttext in + select._hints.iteritems() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + if hint_text: + text += hint_text + " " + if select._prefixes: text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " text += self.get_select_precolumns(select) @@ -701,7 +732,16 @@ class SQLCompiler(engine.Compiled): if froms: text += " \nFROM " - text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms) + + if select._hints: + text += ', '.join([self.process(f, + asfrom=True, fromhints=byfrom, + **kwargs) + for f in froms]) + else: + text += ', '.join([self.process(f, + asfrom=True, **kwargs) + for f in froms]) else: text += self.default_from() @@ -764,20 +804,26 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + str(select._offset) return text - def visit_table(self, table, asfrom=False, **kwargs): - if asfrom: + def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: if getattr(table, "schema", None): - return self.preparer.quote_schema(table.schema, table.quote_schema) + \ + ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \ "." + self.preparer.quote(table.name, table.quote) else: - return self.preparer.quote(table.name, table.quote) + ret = self.preparer.quote(table.name, table.quote) + if fromhints and table in fromhints: + hinttext = self.get_from_hint_text(table, fromhints[table]) + if hinttext: + ret += " " + hinttext + return ret else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + \ + return (self.process(join.left, asfrom=True, **kwargs) + \ (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ - self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + self.process(join.right, asfrom=True, **kwargs) + " ON " + \ + self.process(join.onclause, **kwargs)) def visit_sequence(self, seq): return None diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1e02ba96a..3aaa06fd6 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3557,6 +3557,7 @@ class Select(_SelectBaseMixin, FromClause): __visit_name__ = 'select' _prefixes = () + _hints = util.frozendict() def __init__(self, columns, @@ -3659,7 +3660,34 @@ class Select(_SelectBaseMixin, FromClause): """Return the displayed list of FromClause elements.""" return self._get_display_froms() - + + @_generative + def with_hint(self, selectable, text, dialect_name=None): + """Add an indexing hint for the given selectable to this :class:`Select`. + + The text of the hint is written specific to a specific backend, and + typically uses Python string substitution syntax to render the name + of the table or alias, such as for Oracle:: + + select([mytable]).with_hint(mytable, "+ index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular hint + to a particular backend. Such as, to add hints for both Oracle and + Sybase simultaneously:: + + select([mytable]).\ + with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + """ + if not dialect_name: + dialect_name = '*' + self._hints = self._hints.union({(selectable, dialect_name):text}) + @property def type(self): raise exc.InvalidRequestError("Select objects don't have a type. " diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 74651a9d1..d5575e0e7 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -579,7 +579,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return None elif self.exclude and col in self.exclude: return None - + return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): @@ -587,11 +587,13 @@ class ColumnAdapter(ClauseAdapter): Provides the ability to "wrap" this ClauseAdapter around another, a columns dictionary which returns - cached, adapted elements given an original, and an + adapted elements given an original, and an adapted_row() factory. """ - def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None, adapt_required=False): + def __init__(self, selectable, equivalents=None, + chain_to=None, include=None, + exclude=None, adapt_required=False): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: self.chain(chain_to) @@ -617,7 +619,7 @@ class ColumnAdapter(ClauseAdapter): return locate def _locate_col(self, col): - c = self._corresponding_column(col, False) + c = self._corresponding_column(col, True) if c is None: c = self.adapt_clause(col) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 4a54375f8..799486c02 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -40,16 +40,17 @@ class VisitableType(type): # set up an optimized visit dispatch function # for use by the compiler - visit_name = cls.__visit_name__ - if isinstance(visit_name, str): - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - return getter(visitor)(self, **kw) - else: - def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) - - cls._compiler_dispatch = _compiler_dispatch + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): + return getter(visitor)(self, **kw) + else: + def _compiler_dispatch(self, visitor, **kw): + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + cls._compiler_dispatch = _compiler_dispatch super(VisitableType, cls).__init__(clsname, bases, clsdict) |
