diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2013-06-22 07:47:02 -0700 | 
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2013-06-22 07:47:02 -0700 | 
| commit | 29fa6913be46c4e4c95b2b2810baea24c4b211dd (patch) | |
| tree | 858b755e10ec1dd30235c9f96925f56fa4361544 /test/dialect/test_postgresql.py | |
| parent | 8c555f24b197832b9944f25d47d5989aa942bdea (diff) | |
| parent | b2da12e070e9d83bea5284dae11b8e6d4d509818 (diff) | |
| download | sqlalchemy-29fa6913be46c4e4c95b2b2810baea24c4b211dd.tar.gz | |
Merge pull request #5 from cjw296/pg-ranges
Support for Postgres range types.
Diffstat (limited to 'test/dialect/test_postgresql.py')
| -rw-r--r-- | test/dialect/test_postgresql.py | 330 | 
1 files changed, 329 insertions, 1 deletions
| diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index d1ba960c4..88554a34d 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -17,7 +17,9 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \  from sqlalchemy.orm import Session, mapper, aliased  from sqlalchemy import exc, schema, types  from sqlalchemy.dialects.postgresql import base as postgresql -from sqlalchemy.dialects.postgresql import HSTORE, hstore, array +from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ +            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ +            ExcludeConstraint  import decimal  from sqlalchemy import util  from sqlalchemy.testing.util import round_decimal @@ -182,6 +184,53 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):                              'USING hash (data)',                              dialect=postgresql.dialect()) +    def test_exclude_constraint_min(self): +        m = MetaData() +        tbl = Table('testtbl', m,  +                    Column('room', Integer, primary_key=True)) +        cons = ExcludeConstraint(('room', '=')) +        tbl.append_constraint(cons) +        self.assert_compile(schema.AddConstraint(cons), +                            'ALTER TABLE testtbl ADD EXCLUDE USING gist ' +                            '(room WITH =)', +                            dialect=postgresql.dialect()) + +    def test_exclude_constraint_full(self): +        m = MetaData() +        room = Column('room', Integer, primary_key=True) +        tbl = Table('testtbl', m, +                    room, +                    Column('during', TSRANGE)) +        room = Column('room', Integer, primary_key=True) +        cons = ExcludeConstraint((room, '='), ('during', '&&'), +                                 name='my_name', +                                 using='gist', +                                 where="room > 100", +                                 deferrable=True, +                                 initially='immediate') +        tbl.append_constraint(cons) +        self.assert_compile(schema.AddConstraint(cons), +                            'ALTER TABLE testtbl ADD CONSTRAINT my_name ' +                            'EXCLUDE USING gist ' +                            '(room WITH =, during WITH ''&&) WHERE ' +                            '(room > 100) DEFERRABLE INITIALLY immediate', +                            dialect=postgresql.dialect()) + +    def test_exclude_constraint_copy(self): +        m = MetaData() +        cons = ExcludeConstraint(('room', '=')) +        tbl = Table('testtbl', m,  +              Column('room', Integer, primary_key=True), +              cons) +        # apparently you can't copy a ColumnCollectionConstraint until +        # after it has been bound to a table... +        cons_copy = cons.copy() +        tbl.append_constraint(cons_copy) +        self.assert_compile(schema.AddConstraint(cons_copy), +                            'ALTER TABLE testtbl ADD EXCLUDE USING gist ' +                            '(room WITH =)', +                            dialect=postgresql.dialect()) +      def test_substring(self):          self.assert_compile(func.substring('abc', 1, 2),                              'SUBSTRING(%(substring_1)s FROM %(substring_2)s ' @@ -3242,3 +3291,282 @@ class HStoreRoundTripTest(fixtures.TablesTest):      def test_unicode_round_trip_native(self):          engine = testing.db          self._test_unicode_round_trip(engine) + +class _RangeTypeMixin(object): +    __requires__ = 'range_types', +    __dialect__ = 'postgresql+psycopg2' + +    @property +    def extras(self): +        # done this way so we don't get ImportErrors with +        # older psycopg2 versions. +        from psycopg2 import extras +        return extras +     +    @classmethod +    def define_tables(cls, metadata): +        # no reason ranges shouldn't be primary keys, +        # so lets just use them as such +        table = Table('data_table', metadata, +            Column('range', cls._col_type, primary_key=True), +        ) +        cls.col = table.c.range + +    def test_actual_type(self): +        eq_(str(self._col_type()), self._col_str) +         +    def test_reflect(self): +        from sqlalchemy import inspect +        insp = inspect(testing.db) +        cols = insp.get_columns('data_table') +        assert isinstance(cols[0]['type'], self._col_type) + +    def _assert_data(self): +        data = testing.db.execute( +            select([self.tables.data_table.c.range]) +        ).fetchall() +        eq_(data, [(self._data_obj(), )]) + +    def test_insert_obj(self): +        testing.db.engine.execute( +            self.tables.data_table.insert(), +            {'range': self._data_obj()} +        ) +        self._assert_data() + +    def test_insert_text(self): +        testing.db.engine.execute( +            self.tables.data_table.insert(), +            {'range': self._data_str} +        ) +        self._assert_data() + +    # operator tests +         +    def _test_clause(self, colclause, expected): +        dialect = postgresql.dialect() +        compiled = str(colclause.compile(dialect=dialect)) +        eq_(compiled, expected) + +    def test_where_equal(self): +        self._test_clause( +            self.col==self._data_str, +            "data_table.range = %(range_1)s" +        ) + +    def test_where_not_equal(self): +        self._test_clause( +            self.col!=self._data_str, +            "data_table.range <> %(range_1)s" +        ) + +    def test_where_less_than(self): +        self._test_clause( +            self.col < self._data_str, +            "data_table.range < %(range_1)s" +        ) + +    def test_where_greater_than(self): +        self._test_clause( +            self.col > self._data_str, +            "data_table.range > %(range_1)s" +        ) + +    def test_where_less_than_or_equal(self): +        self._test_clause( +            self.col <= self._data_str, +            "data_table.range <= %(range_1)s" +        ) + +    def test_where_greater_than_or_equal(self): +        self._test_clause( +            self.col >= self._data_str, +            "data_table.range >= %(range_1)s" +        ) + +    def test_contains(self): +        self._test_clause( +            self.col.contains(self._data_str), +            "data_table.range @> %(range_1)s" +        ) + +    def test_contained_by(self): +        self._test_clause( +            self.col.contained_by(self._data_str), +            "data_table.range <@ %(range_1)s" +        ) + +    def test_overlaps(self): +        self._test_clause( +            self.col.overlaps(self._data_str), +            "data_table.range && %(range_1)s" +        ) + +    def test_strictly_left_of(self): +        self._test_clause( +            self.col << self._data_str, +            "data_table.range << %(range_1)s" +        ) +        self._test_clause( +            self.col.strictly_left_of(self._data_str), +            "data_table.range << %(range_1)s" +        ) + +    def test_strictly_right_of(self): +        self._test_clause( +            self.col >> self._data_str, +            "data_table.range >> %(range_1)s" +        ) +        self._test_clause( +            self.col.strictly_right_of(self._data_str), +            "data_table.range >> %(range_1)s" +        ) + +    def test_not_extend_right_of(self): +        self._test_clause( +            self.col.not_extend_right_of(self._data_str), +            "data_table.range &< %(range_1)s" +        ) + +    def test_not_extend_left_of(self): +        self._test_clause( +            self.col.not_extend_left_of(self._data_str), +            "data_table.range &> %(range_1)s" +        ) + +    def test_adjacent_to(self): +        self._test_clause( +            self.col.adjacent_to(self._data_str), +            "data_table.range -|- %(range_1)s" +        ) + +    def test_union(self): +        self._test_clause( +            self.col + self.col, +            "data_table.range + data_table.range" +        ) + +    def test_union_result(self): +        # insert +        testing.db.engine.execute( +            self.tables.data_table.insert(), +            {'range': self._data_str} +        ) +        # select +        range = self.tables.data_table.c.range +        data = testing.db.execute( +            select([range + range]) +            ).fetchall() +        eq_(data, [(self._data_obj(), )]) +         + +    def test_intersection(self): +        self._test_clause( +            self.col * self.col, +            "data_table.range * data_table.range" +        ) + +    def test_intersection_result(self): +        # insert +        testing.db.engine.execute( +            self.tables.data_table.insert(), +            {'range': self._data_str} +        ) +        # select +        range = self.tables.data_table.c.range +        data = testing.db.execute( +            select([range * range]) +            ).fetchall() +        eq_(data, [(self._data_obj(), )]) +         +    def test_different(self): +        self._test_clause( +            self.col - self.col, +            "data_table.range - data_table.range" +        ) + +    def test_difference_result(self): +        # insert +        testing.db.engine.execute( +            self.tables.data_table.insert(), +            {'range': self._data_str} +        ) +        # select +        range = self.tables.data_table.c.range +        data = testing.db.execute( +            select([range - range]) +            ).fetchall() +        eq_(data, [(self._data_obj().__class__(empty=True), )]) +         +class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = INT4RANGE +    _col_str = 'INT4RANGE' +    _data_str = '[1,2)' +    def _data_obj(self): +        return self.extras.NumericRange(1, 2) + +class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = INT8RANGE +    _col_str = 'INT8RANGE' +    _data_str = '[9223372036854775806,9223372036854775807)' +    def _data_obj(self): +        return self.extras.NumericRange( +            9223372036854775806, 9223372036854775807 +            ) + +class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = NUMRANGE +    _col_str = 'NUMRANGE' +    _data_str = '[1.0,2.0)' +    def _data_obj(self): +        return self.extras.NumericRange( +            decimal.Decimal('1.0'), decimal.Decimal('2.0') +            ) + +class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = DATERANGE +    _col_str = 'DATERANGE' +    _data_str = '[2013-03-23,2013-03-24)' +    def _data_obj(self): +        return self.extras.DateRange( +            datetime.date(2013, 3, 23), datetime.date(2013, 3, 24) +            ) + +class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = TSRANGE +    _col_str = 'TSRANGE' +    _data_str = '[2013-03-23 14:30,2013-03-23 23:30)' +    def _data_obj(self): +        return self.extras.DateTimeRange( +            datetime.datetime(2013, 3, 23, 14, 30), +            datetime.datetime(2013, 3, 23, 23, 30) +            ) + +class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest): + +    _col_type = TSTZRANGE +    _col_str = 'TSTZRANGE' + +    # make sure we use one, steady timestamp with timezone pair +    # for all parts of all these tests +    _tstzs = None +    def tstzs(self): +        if self._tstzs is None: +            lower = testing.db.connect().scalar( +                func.current_timestamp().select() +                ) +            upper = lower+datetime.timedelta(1) +            self._tstzs = (lower, upper) +        return self._tstzs + +    @property +    def _data_str(self): +        return '[%s,%s)' % self.tstzs() +     +    def _data_obj(self): +        return self.extras.DateTimeTZRange(*self.tstzs()) | 
