diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-05-27 21:05:16 -0400 | 
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-05-27 21:05:16 -0400 | 
| commit | 0adcfea0d30376461c75cced87481f04a42481c0 (patch) | |
| tree | 9072ec2d98598591629dd9848cfa91ca62764f47 /lib/sqlalchemy | |
| parent | a9ed16f80d5e6d96d800004953b555b9cf1a592e (diff) | |
| download | sqlalchemy-0adcfea0d30376461c75cced87481f04a42481c0.tar.gz | |
still not locating more nested expressions, may need to match on name
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 48 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 83 | 
3 files changed, 121 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8eb8c5fd8..5fbfa34f3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -390,15 +390,13 @@ class SQLCompiler(engine.Compiled):                              add_to_result_map=None,                              within_label_clause=False,                              within_columns_clause=False, -                            order_by_labels=None, **kw): +                            render_label_as_label=None, +                            **kw):          # only render labels within the columns clause          # or ORDER BY clause of a select.  dialect-specific compilers          # can modify this behavior. -#        if order_by_labels: -#            import pdb -#            pdb.set_trace()          render_label_with_as = within_columns_clause and not within_label_clause -        render_label_only = order_by_labels and label in order_by_labels +        render_label_only = render_label_as_label is label          if render_label_only or render_label_with_as:              if isinstance(label.name, sql._truncated_label): @@ -518,7 +516,9 @@ class SQLCompiler(engine.Compiled):      def visit_false(self, expr, **kw):          return 'false' -    def visit_clauselist(self, clauselist, **kwargs): +    def visit_clauselist(self, clauselist, order_by_select=None, **kw): +        if order_by_select is not None: +            return self._order_by_clauselist(clauselist, order_by_select, **kw)          sep = clauselist.operator          if sep is None:              sep = " " @@ -526,8 +526,34 @@ class SQLCompiler(engine.Compiled):              sep = OPERATORS[clauselist.operator]          return sep.join(                      s for s in -                    (c._compiler_dispatch(self, **kwargs) -                    for c in clauselist.clauses) +                    ( +                        c._compiler_dispatch(self, **kw) +                        for c in clauselist.clauses) +                    if s) + +    def _order_by_clauselist(self, clauselist, order_by_select, **kw): +        # look through raw columns collection for labels. +        # note that its OK we aren't expanding tables and other selectables +        # here; we can only add a label in the ORDER BY for an individual +        # label expression in the columns clause. +        raw_col = set(order_by_select._raw_columns) +        def label_ok(c): +            if c in raw_col: +                return c +            elif getattr(c, 'modifier', None) in \ +                    (operators.desc_op, operators.asc_op) and \ +                    c.element.proxy_set.intersection(raw_col): +                return c.element +            else: +                return None + +        return ", ".join( +                    s for s in +                    ( +                        c._compiler_dispatch(self, +                                render_label_as_label=label_ok(c), +                                **kw) +                        for c in clauselist.clauses)                      if s)      def visit_case(self, clause, **kwargs): @@ -1192,12 +1218,12 @@ class SQLCompiler(engine.Compiled):          if select._order_by_clause.clauses:              if self.dialect.supports_simple_order_by_label: -                order_by_labels = set(c for k, c in select._columns_plus_names) +                order_by_select = select              else: -                order_by_labels = None +                order_by_select = None              text += self.order_by_clause(select, -                                    order_by_labels=order_by_labels, **kwargs) +                            order_by_select=order_by_select, **kwargs)          if select._limit is not None or select._offset is not None:              text += self.limit_clause(select)          if select.for_update: diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index f65dd1a34..780aa40aa 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -2,6 +2,7 @@  from sqlalchemy.testing.suite.test_ddl import *  from sqlalchemy.testing.suite.test_insert import *  from sqlalchemy.testing.suite.test_sequence import * +from sqlalchemy.testing.suite.test_select import *  from sqlalchemy.testing.suite.test_results import *  from sqlalchemy.testing.suite.test_update_delete import *  from sqlalchemy.testing.suite.test_reflection import * diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 000000000..b040c8f25 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,83 @@ +from .. import fixtures, config +from ..assertions import eq_ + +from sqlalchemy import Integer, String, select, func + +from ..schema import Table, Column + + +class OrderByLabelTest(fixtures.TablesTest): +    """Test the dialect sends appropriate ORDER BY expressions when +    labels are used. + +    This essentially exercises the "supports_simple_order_by_label" +    setting. + +    """ +    @classmethod +    def define_tables(cls, metadata): +        Table("some_table", metadata, +            Column('id', Integer, primary_key=True), +            Column('x', Integer), +            Column('y', Integer), +            Column('q', String(50)), +            Column('p', String(50)) +            ) + +    @classmethod +    def insert_data(cls): +        config.db.execute( +            cls.tables.some_table.insert(), +            [ +                {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, +                {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, +                {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, +            ] +        ) + +    def _assert_result(self, select, result): +        eq_( +            config.db.execute(select).fetchall(), +            result +        ) + +    def test_plain(self): +        table = self.tables.some_table +        lx = table.c.x.label('lx') +        self._assert_result( +            select([lx]).order_by(lx), +            [(1, ), (2, ), (3, )] +        ) + +    def test_composed_int(self): +        table = self.tables.some_table +        lx = (table.c.x + table.c.y).label('lx') +        self._assert_result( +            select([lx]).order_by(lx), +            [(3, ), (5, ), (7, )] +        ) + +    def test_composed_multiple(self): +        table = self.tables.some_table +        lx = (table.c.x + table.c.y).label('lx') +        ly = (func.lower(table.c.q) + table.c.p).label('ly') +        self._assert_result( +            select([lx, ly]).order_by(lx, ly.desc()), +            [(3, u'q1p3'), (5, u'q2p2'), (7, u'q3p1')] +        ) + +    def test_plain_desc(self): +        table = self.tables.some_table +        lx = table.c.x.label('lx') +        self._assert_result( +            select([lx]).order_by(lx.desc()), +            [(3, ), (2, ), (1, )] +        ) + +    def test_composed_int_desc(self): +        table = self.tables.some_table +        lx = (table.c.x + table.c.y).label('lx') +        self._assert_result( +            select([lx]).order_by(lx.desc()), +            [(7, ), (5, ), (3, )] +        )  | 
