diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 56 | 
1 files changed, 33 insertions, 23 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bc75621e9..f9331a73e 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -44,12 +44,12 @@ def emits_warning(*messages):                          category=sa_exc.SAPendingDeprecationWarning)]          if not messages:              filters.append(dict(action='ignore', -                                 category=sa_exc.SAWarning)) +                                category=sa_exc.SAWarning))          else:              filters.extend(dict(action='ignore', -                                 message=message, -                                 category=sa_exc.SAWarning) -                            for message in messages) +                                message=message, +                                category=sa_exc.SAWarning) +                           for message in messages)          for f in filters:              warnings.filterwarnings(**f)          try: @@ -103,6 +103,7 @@ def uses_deprecated(*messages):              return fn(*args, **kw)      return decorate +  @contextlib.contextmanager  def expect_deprecated(*messages):      # todo: should probably be strict about this, too @@ -118,8 +119,8 @@ def expect_deprecated(*messages):                    category=sa_exc.SADeprecationWarning)               for message in               [(m.startswith('//') and -                ('Call to deprecated function ' + m[2:]) or m) -               for m in messages]]) +               ('Call to deprecated function ' + m[2:]) or m) +              for m in messages]])      for f in filters:          warnings.filterwarnings(**f) @@ -140,6 +141,8 @@ def global_cleanup_assertions():      _assert_no_stray_pool_connections()  _STRAY_CONNECTION_FAILURES = 0 + +  def _assert_no_stray_pool_connections():      global _STRAY_CONNECTION_FAILURES @@ -156,7 +159,7 @@ def _assert_no_stray_pool_connections():          _STRAY_CONNECTION_FAILURES += 1          print("Encountered a stray connection in test cleanup: %s" -                        % str(pool._refs)) +              % str(pool._refs))          # then do a real GC sweep.   We shouldn't even be here          # so a single sweep should really be doing it, otherwise          # there's probably a real unreachable cycle somewhere. @@ -218,17 +221,18 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):          callable_(*args, **kwargs)          assert False, "Callable did not raise an exception"      except except_cls as e: -        assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) +        assert re.search( +            msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)          print(util.text_type(e).encode('utf-8'))  class AssertsCompiledSQL(object):      def assert_compile(self, clause, result, params=None, -                        checkparams=None, dialect=None, -                        checkpositional=None, -                        use_default_dialect=False, -                        allow_dialect_select=False, -                        literal_binds=False): +                       checkparams=None, dialect=None, +                       checkpositional=None, +                       use_default_dialect=False, +                       allow_dialect_select=False, +                       literal_binds=False):          if use_default_dialect:              dialect = default.DefaultDialect()          elif allow_dialect_select: @@ -244,7 +248,6 @@ class AssertsCompiledSQL(object):              elif isinstance(dialect, util.string_types):                  dialect = url.URL(dialect).get_dialect()() -          kw = {}          compile_kwargs = {} @@ -268,10 +271,15 @@ class AssertsCompiledSQL(object):          if util.py3k:              param_str = param_str.encode('utf-8').decode('ascii', 'ignore') -            print(("\nSQL String:\n" + util.text_type(c) + param_str).encode('utf-8')) +            print( +                ("\nSQL String:\n" + +                 util.text_type(c) + +                 param_str).encode('utf-8'))          else: -            print("\nSQL String:\n" + util.text_type(c).encode('utf-8') + param_str) - +            print( +                "\nSQL String:\n" + +                util.text_type(c).encode('utf-8') + +                param_str)          cc = re.sub(r'[\n\t]', '', util.text_type(c)) @@ -296,7 +304,7 @@ class ComparesTables(object):              if strict_types:                  msg = "Type '%s' doesn't correspond to type '%s'" -                assert type(reflected_c.type) is type(c.type), \ +                assert isinstance(reflected_c.type, type(c.type)), \                      msg % (reflected_c.type, c.type)              else:                  self.assert_types_base(reflected_c, c) @@ -318,8 +326,8 @@ class ComparesTables(object):      def assert_types_base(self, c1, c2):          assert c1.type._compare_type_affinity(c2.type),\ -                "On column %r, type '%s' doesn't correspond to type '%s'" % \ -                (c1.name, c1.type, c2.type) +            "On column %r, type '%s' doesn't correspond to type '%s'" % \ +            (c1.name, c1.type, c2.type)  class AssertsExecutionResults(object): @@ -363,7 +371,8 @@ class AssertsExecutionResults(object):          found = util.IdentitySet(result)          expected = set([immutabledict(e) for e in expected]) -        for wrong in util.itertools_filterfalse(lambda o: type(o) == cls, found): +        for wrong in util.itertools_filterfalse(lambda o: +                                                isinstance(o, cls), found):              fail('Unexpected type "%s", expected "%s"' % (                  type(wrong).__name__, cls.__name__)) @@ -394,7 +403,7 @@ class AssertsExecutionResults(object):              else:                  fail(                      "Expected %s instance with attributes %s not found." % ( -                    cls.__name__, repr(expected_item))) +                        cls.__name__, repr(expected_item)))          return True      def assert_sql_execution(self, db, callable_, *rules): @@ -406,7 +415,8 @@ class AssertsExecutionResults(object):              assertsql.asserter.clear_rules()      def assert_sql(self, db, callable_, list_, with_sequences=None): -        if with_sequences is not None and config.db.dialect.supports_sequences: +        if (with_sequences is not None and +                config.db.dialect.supports_sequences):              rules = with_sequences          else:              rules = list_  | 
