| field | value | date |
|---|---|---|
| author | Brian Jarrett <celttechie@gmail.com> | 2014-07-20 12:44:40 -0400 |
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-07-20 12:44:40 -0400 |
| commit | cca03097f47f22783d42d1853faac6cf84607c5a (patch) | |
| tree | 4fe1a63d03a2d88d1cf37e1167759dfaf84f4ce7 /lib/sqlalchemy/testing | |
| parent | 827329a0cca5351094a1a86b6b2be2b9182f0ae2 (diff) | |
| download | sqlalchemy-cca03097f47f22783d42d1853faac6cf84607c5a.tar.gz | |
- apply pep8 formatting to sqlalchemy/sql, sqlalchemy/util, sqlalchemy/dialects, sqlalchemy/orm, sqlalchemy/event, sqlalchemy/testing
Diffstat (limited to 'lib/sqlalchemy/testing')
28 files changed, 601 insertions, 533 deletions
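
The hunks that follow are almost entirely whitespace and line-wrapping changes rather than behavioral ones: continuation lines re-indented under the opening call, long argument lists and string literals wrapped, and two blank lines restored between top-level definitions, as PEP 8 asks. The sketch below is a rough, hedged illustration of that dominant pattern — the helper names are invented for the example, and only the warning text is borrowed from the `engines.py` hunk further down; it is not code taken verbatim from the patch.

```python
import warnings


def warn_old_style(error):
    # Pre-patch layout: the continuation lines of the call carry an
    # inconsistent extra indent, which the pep8 tool reports as a
    # continuation-line indentation error (E12x).
    warnings.warn(
                "testing_reaper couldn't "
                "rollback/close connection: %s" % error)


def warn_new_style(error):
    # Post-patch layout: continuation lines sit at a plain four-space
    # hanging indent, and two blank lines separate top-level
    # definitions (E302), mirroring the kind of change made in the
    # engines.py hunk below.
    warnings.warn(
        "testing_reaper couldn't "
        "rollback/close connection: %s" % error)


if __name__ == "__main__":
    warnings.simplefilter("always")
    warn_old_style("connection already closed")
    warn_new_style("connection already closed")
```

Most of the 601 insertions and 533 deletions counted above appear to be pairs of this kind, where a line is removed and re-added with different leading whitespace or wrapping but identical content.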
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 453f2329f..8f8f56412 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -19,9 +19,9 @@ def against(*queries): return _against(config._current, *queries) from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ - eq_, ne_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults, expect_deprecated + eq_, ne_, is_, is_not_, startswith_, assert_raises, \ + assert_raises_message, AssertsCompiledSQL, ComparesTables, \ + AssertsExecutionResults, expect_deprecated from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict @@ -30,4 +30,4 @@ crashes = skip from .config import db from .config import requirements as requires -from . import mock
\ No newline at end of file +from . import mock diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bc75621e9..f9331a73e 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -44,12 +44,12 @@ def emits_warning(*messages): category=sa_exc.SAPendingDeprecationWarning)] if not messages: filters.append(dict(action='ignore', - category=sa_exc.SAWarning)) + category=sa_exc.SAWarning)) else: filters.extend(dict(action='ignore', - message=message, - category=sa_exc.SAWarning) - for message in messages) + message=message, + category=sa_exc.SAWarning) + for message in messages) for f in filters: warnings.filterwarnings(**f) try: @@ -103,6 +103,7 @@ def uses_deprecated(*messages): return fn(*args, **kw) return decorate + @contextlib.contextmanager def expect_deprecated(*messages): # todo: should probably be strict about this, too @@ -118,8 +119,8 @@ def expect_deprecated(*messages): category=sa_exc.SADeprecationWarning) for message in [(m.startswith('//') and - ('Call to deprecated function ' + m[2:]) or m) - for m in messages]]) + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) for f in filters: warnings.filterwarnings(**f) @@ -140,6 +141,8 @@ def global_cleanup_assertions(): _assert_no_stray_pool_connections() _STRAY_CONNECTION_FAILURES = 0 + + def _assert_no_stray_pool_connections(): global _STRAY_CONNECTION_FAILURES @@ -156,7 +159,7 @@ def _assert_no_stray_pool_connections(): _STRAY_CONNECTION_FAILURES += 1 print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + % str(pool._refs)) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. 
@@ -218,17 +221,18 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) + assert re.search( + msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) print(util.text_type(e).encode('utf-8')) class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False): + checkparams=None, dialect=None, + checkpositional=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: @@ -244,7 +248,6 @@ class AssertsCompiledSQL(object): elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() - kw = {} compile_kwargs = {} @@ -268,10 +271,15 @@ class AssertsCompiledSQL(object): if util.py3k: param_str = param_str.encode('utf-8').decode('ascii', 'ignore') - print(("\nSQL String:\n" + util.text_type(c) + param_str).encode('utf-8')) + print( + ("\nSQL String:\n" + + util.text_type(c) + + param_str).encode('utf-8')) else: - print("\nSQL String:\n" + util.text_type(c).encode('utf-8') + param_str) - + print( + "\nSQL String:\n" + + util.text_type(c).encode('utf-8') + + param_str) cc = re.sub(r'[\n\t]', '', util.text_type(c)) @@ -296,7 +304,7 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert type(reflected_c.type) is type(c.type), \ + assert isinstance(reflected_c.type, type(c.type)), \ msg % (reflected_c.type, c.type) else: self.assert_types_base(reflected_c, c) @@ -318,8 +326,8 @@ class ComparesTables(object): def assert_types_base(self, c1, c2): assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + "On column %r, type '%s' doesn't correspond to type '%s'" % \ + (c1.name, c1.type, c2.type) class AssertsExecutionResults(object): @@ -363,7 +371,8 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = set([immutabledict(e) for e in expected]) - for wrong in util.itertools_filterfalse(lambda o: type(o) == cls, found): + for wrong in util.itertools_filterfalse(lambda o: + isinstance(o, cls), found): fail('Unexpected type "%s", expected "%s"' % ( type(wrong).__name__, cls.__name__)) @@ -394,7 +403,7 @@ class AssertsExecutionResults(object): else: fail( "Expected %s instance with attributes %s not found." 
% ( - cls.__name__, repr(expected_item))) + cls.__name__, repr(expected_item))) return True def assert_sql_execution(self, db, callable_, *rules): @@ -406,7 +415,8 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_sql(self, db, callable_, list_, with_sequences=None): - if with_sequences is not None and config.db.dialect.supports_sequences: + if (with_sequences is not None and + config.db.dialect.supports_sequences): rules = with_sequences else: rules = list_ diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7b4630c05..bcc999fe3 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -78,7 +78,7 @@ class ExactSQL(SQLMatchRule): return _received_statement = \ _process_engine_statement(context.unicode_statement, - context) + context) _received_parameters = context.compiled_parameters # TODO: remove this step once all unit tests are migrated, as @@ -99,10 +99,10 @@ class ExactSQL(SQLMatchRule): params = {} self._result = equivalent if not self._result: - self._errmsg = \ - 'Testing for exact statement %r exact params %r, '\ - 'received %r with params %r' % (sql, params, - _received_statement, _received_parameters) + self._errmsg = ( + 'Testing for exact statement %r exact params %r, ' + 'received %r with params %r' % + (sql, params, _received_statement, _received_parameters)) class RegexSQL(SQLMatchRule): @@ -119,7 +119,7 @@ class RegexSQL(SQLMatchRule): return _received_statement = \ _process_engine_statement(context.unicode_statement, - context) + context) _received_parameters = context.compiled_parameters equivalent = bool(self.regex.match(_received_statement)) if self.params: @@ -168,9 +168,11 @@ class CompiledSQL(SQLMatchRule): compiled = \ context.compiled.statement.compile(dialect=DefaultDialect()) else: - compiled = \ - context.compiled.statement.compile(dialect=DefaultDialect(), - column_keys=context.compiled.column_keys) + compiled = ( + context.compiled.statement.compile( + dialect=DefaultDialect(), + column_keys=context.compiled.column_keys) + ) _received_statement = re.sub(r'[\n\t]', '', str(compiled)) equivalent = self.statement == _received_statement if self.params: @@ -201,17 +203,19 @@ class CompiledSQL(SQLMatchRule): all_received = [] self._result = equivalent if not self._result: - print('Testing for compiled statement %r partial params '\ - '%r, received %r with params %r' % (self.statement, - all_params, _received_statement, all_received)) - self._errmsg = \ - 'Testing for compiled statement %r partial params %r, '\ - 'received %r with params %r' % (self.statement, - all_params, _received_statement, all_received) - + print('Testing for compiled statement %r partial params ' + '%r, received %r with params %r' % + (self.statement, all_params, + _received_statement, all_received)) + self._errmsg = ( + 'Testing for compiled statement %r partial params %r, ' + 'received %r with params %r' % + (self.statement, all_params, + _received_statement, all_received)) # print self._errmsg + class CountStatements(AssertRule): def __init__(self, count): @@ -248,7 +252,7 @@ class AllOf(AssertRule): executemany): for rule in self.rules: rule.process_cursor_execute(statement, parameters, context, - executemany) + executemany) def is_consumed(self): if not self.rules: @@ -265,6 +269,7 @@ class AllOf(AssertRule): def consume_final(self): return len(self.rules) == 0 + class Or(AllOf): def __init__(self, *rules): self.rules = set(rules) @@ -282,6 +287,7 @@ class Or(AllOf): def 
consume_final(self): assert self._consume_final, "Unsatisified rules remain" + def _process_engine_statement(query, context): if util.jython: @@ -289,7 +295,7 @@ def _process_engine_statement(query, context): query = str(query) if context.engine.name == 'mssql' \ - and query.endswith('; select scope_identity()'): + and query.endswith('; select scope_identity()'): query = query[:-25] query = re.sub(r'\n', '', query) return query @@ -348,6 +354,6 @@ class SQLAssert(object): if self.rules: rule = self.rules[0] rule.process_cursor_execute(statement, parameters, context, - executemany) + executemany) asserter = SQLAssert() diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 66bfbc892..c914434b4 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -15,6 +15,7 @@ file_config = None _current = None + class Config(object): def __init__(self, db, db_opts, options, file_config): self.db = db @@ -52,7 +53,8 @@ class Config(object): def push_engine(cls, db, namespace): assert _current, "Can't push without a default Config set up" cls.push( - Config(db, _current.db_opts, _current.options, _current.file_config), + Config( + db, _current.db_opts, _current.options, _current.file_config), namespace ) diff --git a/lib/sqlalchemy/testing/distutils_run.py b/lib/sqlalchemy/testing/distutils_run.py index d8f8f5931..ecec3ffd5 100644 --- a/lib/sqlalchemy/testing/distutils_run.py +++ b/lib/sqlalchemy/testing/distutils_run.py @@ -5,6 +5,7 @@ custom setuptools/distutils code. import unittest import pytest + class TestSuite(unittest.TestCase): def test_sqlalchemy(self): pytest.main() diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 4136e5292..9052df570 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -17,6 +17,7 @@ import re import warnings from .. import util + class ConnectionKiller(object): def __init__(self): @@ -43,8 +44,8 @@ class ConnectionKiller(object): raise except Exception as e: warnings.warn( - "testing_reaper couldn't " - "rollback/close connection: %s" % e) + "testing_reaper couldn't " + "rollback/close connection: %s" % e) def rollback_all(self): for rec in list(self.proxy_refs): @@ -174,8 +175,8 @@ class ReconnectFixture(object): raise except Exception as e: warnings.warn( - "ReconnectFixture couldn't " - "close connection: %s" % e) + "ReconnectFixture couldn't " + "close connection: %s" % e) def shutdown(self): # TODO: this doesn't cover all cases @@ -236,8 +237,6 @@ def testing_engine(url=None, options=None): return engine - - def mock_engine(dialect_name=None): """Provides a mocking engine based on the current testing.db. @@ -262,7 +261,7 @@ def mock_engine(dialect_name=None): def assert_sql(stmts): recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] - assert recv == stmts, recv + assert recv == stmts, recv def print_sql(): d = engine.dialect @@ -287,6 +286,7 @@ class DBAPIProxyCursor(object): DBAPI-level cursor operations. """ + def __init__(self, engine, conn): self.engine = engine self.connection = conn @@ -312,6 +312,7 @@ class DBAPIProxyConnection(object): DBAPI-level connection operations. 
""" + def __init__(self, engine, cursor_cls): self.conn = self._sqla_unwrap = engine.pool._creator() self.engine = engine @@ -352,20 +353,20 @@ class ReplayableSession(object): if util.py2k: Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', 'UnboundMethodType',)]) + for t in dir(types) if not t.startswith('_')]).\ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) else: Natives = set([getattr(types, t) for t in dir(types) if not t.startswith('_')]).\ - union([type(t) if not isinstance(t, type) - else t for t in __builtins__.values()]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', )]) + union([type(t) if not isinstance(t, type) + else t for t in __builtins__.values()]).\ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', )]) def __init__(self): self.buffer = deque() diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 7bf99918a..3e42955e6 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -86,7 +86,8 @@ class ComparableEntity(BasicEntity): return False if hasattr(value, '__iter__'): - if hasattr(value, '__getitem__') and not hasattr(value, 'keys'): + if hasattr(value, '__getitem__') and not hasattr( + value, 'keys'): if list(value) != list(battr): return False else: diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 41337ea4d..fd43865aa 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -14,6 +14,7 @@ from .. 
import util import contextlib import inspect + class skip_if(object): def __init__(self, predicate, reason=None): self.predicate = _as_predicate(predicate) @@ -55,13 +56,13 @@ class skip_if(object): if self.predicate(config._current): if self.reason: msg = "'%s' : %s" % ( - fn.__name__, - self.reason - ) + fn.__name__, + self.reason + ) else: msg = "'%s': %s" % ( - fn.__name__, self.predicate - ) + fn.__name__, self.predicate + ) raise SkipTest(msg) else: if self._fails_on: @@ -79,6 +80,7 @@ class skip_if(object): self._fails_on = skip_if(fails_on_everything_except(*dbs)) return self + class fails_if(skip_if): def __call__(self, fn): @decorator @@ -150,15 +152,15 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], - } + '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } def __call__(self, config): engine = config.db @@ -178,7 +180,7 @@ class SpecPredicate(Predicate): version = _server_version(engine) oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + or self._ops[self.op] return oper(version, self.spec) else: return True @@ -194,16 +196,16 @@ class SpecPredicate(Predicate): else: if negate: return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + self.db, + self.op, + self.spec + ) else: return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + self.db, + self.op, + self.spec + ) def __str__(self): return self._as_string() @@ -270,7 +272,7 @@ class OrPredicate(Predicate): else: conjunction = " or " return conjunction.join(p._as_string(negate=negate) - for p in self.predicates) + for p in self.predicates) else: return self._str._as_string(negate=negate) @@ -311,8 +313,8 @@ def _server_version(engine): def db_spec(*dbs): return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + [Predicate.as_predicate(db) for db in dbs] + ) def open(): @@ -322,9 +324,11 @@ def open(): def closed(): return skip_if(BooleanPredicate(True, "marked as skip")) + def fails(): return fails_if(BooleanPredicate(True, "expected to fail")) + @decorator def future(fn, *arg): return fails_if(LambdaPredicate(fn), "Future feature") @@ -336,10 +340,10 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): return succeeds_if( - OrPredicate([ + OrPredicate([ SpecPredicate(db) for db in dbs ]) - ) + ) def skip(db, reason=None): @@ -348,7 +352,7 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) ) @@ -359,6 +363,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" 
return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + Predicate.as_predicate(query) + for query in queries + ])(config) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 23d010ec9..7c7b00998 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -18,6 +18,7 @@ from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta # whether or not we use unittest changes things dramatically, # as far as how py.test collection works. + class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -49,6 +50,7 @@ class TestBase(object): if hasattr(self, "tearDown"): self.tearDown() + class TablesTest(TestBase): # 'once', None @@ -222,6 +224,8 @@ class TablesTest(TestBase): for column_values in rows[table]]) from sqlalchemy import event + + class RemovesEvents(object): @util.memoized_property def _event_fns(self): @@ -239,7 +243,6 @@ class RemovesEvents(object): super_.teardown() - class _ORMTest(object): @classmethod @@ -366,14 +369,14 @@ class DeclarativeMappedTest(MappedTest): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls return DeclarativeMeta.__init__( - cls, classname, bases, dict_) + cls, classname, bases, dict_) class DeclarativeBasic(object): __table_cls__ = schema.Table _DeclBase = declarative_base(metadata=cls.metadata, - metaclass=FindFixtureDeclarative, - cls=DeclarativeBasic) + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic) cls.DeclarativeBasic = _DeclBase fn() diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index ccbe8aa92..c6a4d4360 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -17,6 +17,5 @@ else: from mock import MagicMock, Mock, call, patch except ImportError: raise ImportError( - "SQLAlchemy's test suite requires the " - "'mock' library as of 0.8.2.") - + "SQLAlchemy's test suite requires the " + "'mock' library as of 0.8.2.") diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index fe68457e8..5a903aae7 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -63,8 +63,8 @@ class Foo(object): def __eq__(self, other): return other.data == self.data and \ - other.stuff == self.stuff and \ - other.moredata == self.moredata + other.stuff == self.stuff and \ + other.moredata == self.moredata class Bar(object): diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 7262adb4b..e362d6141 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -21,9 +21,10 @@ fixtures = None # no package imports yet! this prevents us from tripping coverage # too soon. 
path = os.path.join(os.path.dirname(__file__), "plugin_base.py") -if sys.version_info >= (3,3): +if sys.version_info >= (3, 3): from importlib import machinery - plugin_base = machinery.SourceFileLoader("plugin_base", path).load_module() + plugin_base = machinery.SourceFileLoader( + "plugin_base", path).load_module() else: import imp plugin_base = imp.load_source("plugin_base", path) @@ -76,20 +77,20 @@ class NoseSQLAlchemy(Plugin): def beforeTest(self, test): plugin_base.before_test(test, - test.test.cls.__module__, - test.test.cls, test.test.method.__name__) + test.test.cls.__module__, + test.test.cls, test.test.method.__name__) def afterTest(self, test): plugin_base.after_test(test) def startContext(self, ctx): if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + or not issubclass(ctx, fixtures.TestBase): return plugin_base.start_test_class(ctx) def stopContext(self, ctx): if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + or not issubclass(ctx, fixtures.TestBase): return plugin_base.stop_test_class(ctx) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index b91fa4d50..2590f3b1e 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -8,8 +8,8 @@ """Testing extensions. this module is designed to work as a testing-framework-agnostic library, -so that we can continue to support nose and also begin adding new functionality -via py.test. +so that we can continue to support nose and also begin adding new +functionality via py.test. """ @@ -50,50 +50,62 @@ logging = None db_opts = {} options = None + def setup_options(make_option): make_option("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for <LOG> (multiple OK)") - make_option("--log-debug", action="callback", type="string", callback=_log, - help="turn on debug logging for <LOG> (multiple OK)") + help="turn on info logging for <LOG> (multiple OK)") + make_option("--log-debug", action="callback", + type="string", callback=_log, + help="turn on debug logging for <LOG> (multiple OK)") make_option("--db", action="append", type="string", dest="db", help="Use prefab database uri. Multiple OK, " - "first one is run by default.") + "first one is run by default.") make_option('--dbs', action='callback', callback=_list_dbs, - help="List available prefab dbs") + help="List available prefab dbs") make_option("--dburi", action="append", type="string", dest="dburi", - help="Database uri. Multiple OK, first one is run by default.") + help="Database uri. Multiple OK, " + "first one is run by default.") make_option("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first") + help="Drop all tables in the target database first") make_option("--backend-only", action="store_true", dest="backend_only", - help="Run only tests marked with __backend__") + help="Run only tests marked with __backend__") make_option("--mockpool", action="store_true", dest="mockpool", - help="Use mock pool (asserts only one connection used)") - make_option("--low-connections", action="store_true", dest="low_connections", - help="Use a low number of distinct connections - i.e. 
for Oracle TNS" - ) - make_option("--reversetop", action="store_true", dest="reversetop", default=False, - help="Use a random-ordering set implementation in the ORM (helps " - "reveal dependency issues)") + help="Use mock pool (asserts only one connection used)") + make_option("--low-connections", action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS") + make_option("--reversetop", action="store_true", + dest="reversetop", default=False, + help="Use a random-ordering set implementation in the ORM " + "(helps reveal dependency issues)") make_option("--requirements", action="callback", type="string", - callback=_requirements_opt, - help="requirements class for testing, overrides setup.cfg") - make_option("--with-cdecimal", action="store_true", dest="cdecimal", default=False, - help="Monkeypatch the cdecimal library into Python 'decimal' for all tests") - make_option("--serverside", action="callback", callback=_server_side_cursors, - help="Turn on server side cursors for PG") - make_option("--mysql-engine", action="store", dest="mysql_engine", default=None, - help="Use the specified MySQL storage engine for all tables, default is " - "a db-default/InnoDB combo.") + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg") + make_option("--with-cdecimal", action="store_true", + dest="cdecimal", default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' " + "for all tests") + make_option("--serverside", action="callback", + callback=_server_side_cursors, + help="Turn on server side cursors for PG") + make_option("--mysql-engine", action="store", + dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, " + "default is a db-default/InnoDB combo.") make_option("--tableopts", action="append", dest="tableopts", default=[], - help="Add a dialect-specific table option, key=value") - make_option("--write-profiles", action="store_true", dest="write_profiles", default=False, - help="Write/update profiling data.") + help="Add a dialect-specific table option, key=value") + make_option("--write-profiles", action="store_true", + dest="write_profiles", default=False, + help="Write/update profiling data.") + def read_config(): global file_config file_config = configparser.ConfigParser() file_config.read(['setup.cfg', 'test.cfg']) + def pre_begin(opt): """things to set up early, before coverage might be setup.""" global options @@ -101,9 +113,11 @@ def pre_begin(opt): for fn in pre_configure: fn(options, file_config) + def set_coverage_flag(value): options.has_coverage = value + def post_begin(): """things to set up later, once we know coverage is running.""" # Lazy setup of other options (post coverage) @@ -113,11 +127,11 @@ def post_begin(): # late imports, has to happen after config as well # as nose plugins like coverage global util, fixtures, engines, exclusions, \ - assertions, warnings, profiling,\ - config, testing + assertions, warnings, profiling,\ + config, testing from sqlalchemy import testing from sqlalchemy.testing import fixtures, engines, exclusions, \ - assertions, warnings, profiling, config + assertions, warnings, profiling, config from sqlalchemy import util @@ -143,6 +157,7 @@ def _list_dbs(*args): def _server_side_cursors(opt_str, value, parser): db_opts['server_side_cursors'] = True + def _requirements_opt(opt_str, value, parser): _setup_requirements(value) @@ -189,8 +204,9 @@ def _engine_uri(options, file_config): for db in 
re.split(r'[,\s]+', db_token): if db not in file_config.options('db'): raise RuntimeError( - "Unknown URI specifier '%s'. Specify --dbs for known uris." - % db) + "Unknown URI specifier '%s'. " + "Specify --dbs for known uris." + % db) else: db_urls.append(file_config.get('db', db)) @@ -211,12 +227,14 @@ def _engine_pool(options, file_config): from sqlalchemy import pool db_opts['poolclass'] = pool.AssertionPool + @post def _requirements(options, file_config): requirement_cls = file_config.get('sqla_testing', "requirement_cls") _setup_requirements(requirement_cls) + def _setup_requirements(argument): from sqlalchemy.testing import config from sqlalchemy import testing @@ -235,6 +253,7 @@ def _setup_requirements(argument): config.requirements = testing.requires = req_cls() + @post def _prep_testing_database(options, file_config): from sqlalchemy.testing import config @@ -250,27 +269,36 @@ def _prep_testing_database(options, file_config): pass else: for vname in view_names: - e.execute(schema._DropView(schema.Table(vname, schema.MetaData()))) + e.execute(schema._DropView( + schema.Table(vname, schema.MetaData()) + )) if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names(schema="test_schema") + view_names = inspector.get_view_names( + schema="test_schema") except NotImplementedError: pass else: for vname in view_names: e.execute(schema._DropView( - schema.Table(vname, - schema.MetaData(), schema="test_schema"))) + schema.Table(vname, schema.MetaData(), + schema="test_schema") + )) - for tname in reversed(inspector.get_table_names(order_by="foreign_key")): - e.execute(schema.DropTable(schema.Table(tname, schema.MetaData()))) + for tname in reversed(inspector.get_table_names( + order_by="foreign_key")): + e.execute(schema.DropTable( + schema.Table(tname, schema.MetaData()) + )) if config.requirements.schemas.enabled_for_config(cfg): for tname in reversed(inspector.get_table_names( - order_by="foreign_key", schema="test_schema")): + order_by="foreign_key", schema="test_schema")): e.execute(schema.DropTable( - schema.Table(tname, schema.MetaData(), schema="test_schema"))) + schema.Table(tname, schema.MetaData(), + schema="test_schema") + )) @post @@ -304,7 +332,7 @@ def _post_setup_options(opt, file_config): def _setup_profiling(options, file_config): from sqlalchemy.testing import profiling profiling._profile_stats = profiling.ProfileStatsFile( - file_config.get('sqla_testing', 'profile_file')) + file_config.get('sqla_testing', 'profile_file')) def want_class(cls): @@ -312,22 +340,24 @@ def want_class(cls): return False elif cls.__name__.startswith('_'): return False - elif config.options.backend_only and not getattr(cls, '__backend__', False): + elif config.options.backend_only and not getattr(cls, '__backend__', + False): return False else: return True + def generate_sub_tests(cls, module): if getattr(cls, '__backend__', False): for cfg in _possible_configs_for_cls(cls): name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver) subcls = type( - name, - (cls, ), - { - "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)), - "__backend__": False} - ) + name, + (cls, ), + { + "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)), + "__backend__": False} + ) setattr(module, name, subcls) yield subcls else: @@ -338,20 +368,24 @@ def start_test_class(cls): _do_skips(cls) _setup_engine(cls) + def stop_test_class(cls): engines.testing_reaper._stop_test_ctx() if not options.low_connections: assertions.global_cleanup_assertions() _restore_engine() 
+ def _restore_engine(): config._current.reset(testing) + def _setup_engine(cls): if getattr(cls, '__engine_options__', None): eng = engines.testing_engine(options=cls.__engine_options__) config._current.push_engine(eng, testing) + def before_test(test, test_module_name, test_class, test_name): # like a nose id, e.g.: @@ -367,10 +401,12 @@ def before_test(test, test_module_name, test_class, test_name): warnings.resetwarnings() profiling._current_test = id_ + def after_test(test): engines.testing_reaper._after_test_ctx() warnings.resetwarnings() + def _possible_configs_for_cls(cls): all_configs = set(config.Config.all_configs()) if cls.__unsupported_on__: @@ -378,16 +414,14 @@ def _possible_configs_for_cls(cls): for config_obj in list(all_configs): if spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on__', None): spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) for config_obj in list(all_configs): if not spec(config_obj): all_configs.remove(config_obj) - - return all_configs + def _do_skips(cls): all_configs = _possible_configs_for_cls(cls) reasons = [] @@ -427,19 +461,17 @@ def _do_skips(cls): for config_obj in list(all_configs): if exclusions.skip_if( exclusions.SpecPredicate(db_spec, op, spec) - ).predicate(config_obj): + ).predicate(config_obj): all_configs.remove(config_obj) - - if not all_configs: raise SkipTest( "'%s' unsupported on DB implementation %s%s" % ( cls.__name__, - ", ".join("'%s' = %s" % ( - config_obj.db.name, - config_obj.db.dialect.server_version_info) - for config_obj in config.Config.all_configs() - ), + ", ".join("'%s' = %s" + % (config_obj.db.name, + config_obj.db.dialect.server_version_info) + for config_obj in config.Config.all_configs() + ), ", ".join(reasons) ) ) @@ -455,6 +487,6 @@ def _do_skips(cls): if config._current not in all_configs: _setup_config(all_configs.pop(), cls) + def _setup_config(config_obj, ctx): config._current.push(config_obj, testing) - diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 74d5cc083..11238bbac 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -4,6 +4,7 @@ import inspect from . 
import plugin_base import collections + def pytest_addoption(parser): group = parser.getgroup("sqlalchemy") @@ -11,7 +12,8 @@ def pytest_addoption(parser): callback_ = kw.pop("callback", None) if callback_: class CallableAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser, namespace, + values, option_string=None): callback_(option_string, values, parser) kw["action"] = CallableAction @@ -20,10 +22,12 @@ def pytest_addoption(parser): plugin_base.setup_options(make_option) plugin_base.read_config() + def pytest_configure(config): plugin_base.pre_begin(config.option) - plugin_base.set_coverage_flag(bool(getattr(config.option, "cov_source", False))) + plugin_base.set_coverage_flag(bool(getattr(config.option, + "cov_source", False))) plugin_base.post_begin() @@ -42,12 +46,14 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) test_classes = set(item.parent for item in items) for test_class in test_classes: - for sub_cls in plugin_base.generate_sub_tests(test_class.cls, test_class.parent.module): + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.parent.module): if sub_cls is not test_class.cls: list_ = rebuilt_items[test_class.cls] - for inst in pytest.Class(sub_cls.__name__, - parent=test_class.parent.parent).collect(): + for inst in pytest.Class( + sub_cls.__name__, + parent=test_class.parent.parent).collect(): list_.extend(inst.collect()) newitems = [] @@ -61,12 +67,10 @@ def pytest_collection_modifyitems(session, config, items): # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) items[:] = sorted(newitems, key=lambda item: ( - item.parent.parent.parent.name, - item.parent.parent.name, - item.name - ) - ) - + item.parent.parent.parent.name, + item.parent.parent.name, + item.name + )) def pytest_pycollect_makeitem(collector, name, obj): @@ -82,6 +86,7 @@ def pytest_pycollect_makeitem(collector, name, obj): _current_class = None + def pytest_runtest_setup(item): # here we seem to get called only based on what we collected # in pytest_collection_modifyitems. So to do class-based stuff @@ -100,10 +105,12 @@ def pytest_runtest_setup(item): # this is needed for the class-level, to ensure that the # teardown runs after the class is completed with its own # class-level teardown... 
- item.parent.parent.addfinalizer(lambda: class_teardown(item.parent.parent)) + item.parent.parent.addfinalizer( + lambda: class_teardown(item.parent.parent)) test_setup(item) + def pytest_runtest_teardown(item): # ...but this works better as the hook here rather than # using a finalizer, as the finalizer seems to get in the way @@ -111,15 +118,19 @@ def pytest_runtest_teardown(item): # py.test assertion stuff instead) test_teardown(item) + def test_setup(item): - plugin_base.before_test(item, - item.parent.module.__name__, item.parent.cls, item.name) + plugin_base.before_test(item, item.parent.module.__name__, + item.parent.cls, item.name) + def test_teardown(item): plugin_base.after_test(item) + def class_setup(item): plugin_base.start_test_class(item.cls) + def class_teardown(item): plugin_base.stop_test_class(item.cls) diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index b818e4e15..75baec987 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -43,12 +43,12 @@ def profiled(target=None, **target_opts): """ profile_config = {'targets': set(), - 'report': True, - 'print_callers': False, - 'print_callees': False, - 'graphic': False, - 'sort': ('time', 'calls'), - 'limit': None} + 'report': True, + 'print_callers': False, + 'print_callees': False, + 'graphic': False, + 'sort': ('time', 'calls'), + 'limit': None} if target is None: target = 'anonymous_target' @@ -67,7 +67,7 @@ def profiled(target=None, **target_opts): limit = target_opts.get('limit', profile_config['limit']) print(("Profile report for target '%s'" % ( target, ) - )) + )) stats = load_stats() stats.sort_stats(*sort_) @@ -97,6 +97,7 @@ class ProfileStatsFile(object): so no json lib :( need to roll something silly """ + def __init__(self, filename): self.write = ( config.options is not None and @@ -177,19 +178,19 @@ class ProfileStatsFile(object): def _header(self): return \ - "# %s\n"\ - "# This file is written out on a per-environment basis.\n"\ - "# For each test in aaa_profiling, the corresponding function and \n"\ - "# environment is located within this file. If it doesn't exist,\n"\ - "# the test is skipped.\n"\ - "# If a callcount does exist, it is compared to what we received. \n"\ - "# assertions are raised if the counts do not match.\n"\ - "# \n"\ - "# To add a new callcount test, apply the function_call_count \n"\ - "# decorator and re-run the tests using the --write-profiles \n"\ - "# option - this file will be rewritten including the new count.\n"\ - "# \n"\ - "" % (self.fname) + "# %s\n"\ + "# This file is written out on a per-environment basis.\n"\ + "# For each test in aaa_profiling, the corresponding function and \n"\ + "# environment is located within this file. If it doesn't exist,\n"\ + "# the test is skipped.\n"\ + "# If a callcount does exist, it is compared to what we received. \n"\ + "# assertions are raised if the counts do not match.\n"\ + "# \n"\ + "# To add a new callcount test, apply the function_call_count \n"\ + "# decorator and re-run the tests using the --write-profiles \n"\ + "# option - this file will be rewritten including the new count.\n"\ + "# \n"\ + "" % (self.fname) def _read(self): try: @@ -225,7 +226,6 @@ class ProfileStatsFile(object): profile_f.close() - def function_call_count(variance=0.05): """Assert a target for a test case's function call count. 
@@ -248,9 +248,9 @@ def function_call_count(variance=0.05): # (not a great idea but we have these in test_zoomark) fn(*args, **kw) raise SkipTest("No profiling stats available on this " - "platform for this function. Run tests with " - "--write-profiles to add statistics to %s for " - "this platform." % _profile_stats.short_fname) + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname) gc_collect() @@ -267,12 +267,12 @@ def function_call_count(variance=0.05): line_no, expected_count = expected print(("Pstats calls: %d Expected %s" % ( - callcount, - expected_count - ) + callcount, + expected_count + ) )) stats.print_stats() - #stats.print_callers() + # stats.print_callers() if expected_count: deviance = int(callcount * variance) @@ -287,8 +287,8 @@ def function_call_count(variance=0.05): "of expected %s. Rerun with --write-profiles to " "regenerate this callcount." % ( - callcount, (variance * 100), - expected_count)) + callcount, (variance * 100), + expected_count)) return fn_result return update_wrapper(wrap, fn) return decorate diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 59578ce7f..3413c0d30 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -21,6 +21,7 @@ from . import exclusions class Requirements(object): pass + class SuiteRequirements(Requirements): @property @@ -64,9 +65,9 @@ class SuiteRequirements(Requirements): # somehow only_if([x, y]) isn't working here, negation/conjunctions # getting confused. return exclusions.only_if( - lambda: self.on_update_cascade.enabled or self.deferrable_fks.enabled - ) - + lambda: self.on_update_cascade.enabled or + self.deferrable_fks.enabled + ) @property def self_referential_foreign_keys(self): @@ -94,13 +95,17 @@ class SuiteRequirements(Requirements): @property def offset(self): - """target database can render OFFSET, or an equivalent, in a SELECT.""" + """target database can render OFFSET, or an equivalent, in a + SELECT. + """ return exclusions.open() @property def bound_limit_offset(self): - """target database can render LIMIT and/or OFFSET using a bound parameter""" + """target database can render LIMIT and/or OFFSET using a bound + parameter + """ return exclusions.open() @@ -159,17 +164,16 @@ class SuiteRequirements(Requirements): return exclusions.open() - @property def empty_inserts(self): """target platform supports INSERT with no values, i.e. INSERT DEFAULT VALUES or equivalent.""" return exclusions.only_if( - lambda config: config.db.dialect.supports_empty_insert or \ - config.db.dialect.supports_default_values, - "empty inserts not supported" - ) + lambda config: config.db.dialect.supports_empty_insert or + config.db.dialect.supports_default_values, + "empty inserts not supported" + ) @property def insert_from_select(self): @@ -182,9 +186,9 @@ class SuiteRequirements(Requirements): """target platform supports RETURNING.""" return exclusions.only_if( - lambda config: config.db.dialect.implicit_returning, - "'returning' not supported by database" - ) + lambda config: config.db.dialect.implicit_returning, + "'returning' not supported by database" + ) @property def duplicate_names_in_cursor_description(self): @@ -199,9 +203,9 @@ class SuiteRequirements(Requirements): UPPERCASE as case insensitive names.""" return exclusions.skip_if( - lambda config: not config.db.dialect.requires_name_normalize, - "Backend does not require denormalized names." 
- ) + lambda config: not config.db.dialect.requires_name_normalize, + "Backend does not require denormalized names." + ) @property def multivalues_inserts(self): @@ -209,10 +213,9 @@ class SuiteRequirements(Requirements): INSERT statement.""" return exclusions.skip_if( - lambda config: not config.db.dialect.supports_multivalues_insert, - "Backend does not support multirow inserts." - ) - + lambda config: not config.db.dialect.supports_multivalues_insert, + "Backend does not support multirow inserts." + ) @property def implements_get_lastrowid(self): @@ -260,8 +263,8 @@ class SuiteRequirements(Requirements): """Target database must support SEQUENCEs.""" return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences - ], "no sequence support") + lambda config: config.db.dialect.supports_sequences + ], "no sequence support") @property def sequences_optional(self): @@ -269,13 +272,9 @@ class SuiteRequirements(Requirements): as a means of generating new PK values.""" return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences and \ - config.db.dialect.sequences_optional - ], "no sequence support, or sequences not optional") - - - - + lambda config: config.db.dialect.supports_sequences and + config.db.dialect.sequences_optional + ], "no sequence support, or sequences not optional") @property def reflects_pk_names(self): @@ -339,7 +338,9 @@ class SuiteRequirements(Requirements): @property def unicode_ddl(self): - """Target driver must support some degree of non-ascii symbol names.""" + """Target driver must support some degree of non-ascii symbol + names. + """ return exclusions.closed() @property @@ -531,7 +532,6 @@ class SuiteRequirements(Requirements): return exclusions.closed() - @property def update_from(self): """Target must support UPDATE..FROM syntax""" @@ -587,7 +587,9 @@ class SuiteRequirements(Requirements): @property def unicode_connections(self): - """Target driver must support non-ASCII characters being passed at all.""" + """Target driver must support non-ASCII characters being passed at + all. + """ return exclusions.open() @property @@ -600,11 +602,12 @@ class SuiteRequirements(Requirements): """Test environment must allow ad-hoc engine/connection creation. DBs that scale poorly for many connections, even when closed, i.e. - Oracle, may use the "--low-connections" option which flags this requirement - as not present. + Oracle, may use the "--low-connections" option which flags this + requirement as not present. """ - return exclusions.skip_if(lambda config: config.options.low_connections) + return exclusions.skip_if( + lambda config: config.options.low_connections) def _has_mysql_on_windows(self, config): return False @@ -619,8 +622,8 @@ class SuiteRequirements(Requirements): @property def cextensions(self): return exclusions.skip_if( - lambda: not self._has_cextensions(), "C extensions not installed" - ) + lambda: not self._has_cextensions(), "C extensions not installed" + ) def _has_sqlite(self): from sqlalchemy import create_engine diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index d0c9afeeb..df254520b 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -38,6 +38,7 @@ import nose def main(): nose.main(addplugins=[NoseSQLAlchemy()]) + def setup_py_test(): """Runner to use for the 'test_suite' entry of your setup.py. 
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 2e2a9b5ee..1cb356dd7 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -72,7 +72,7 @@ def Column(*args, **kw): col = schema.Column(*args, **kw) if 'test_needs_autoincrement' in test_opts and \ - kw.get('primary_key', False): + kw.get('primary_key', False): # allow any test suite to pick up on this col.info['test_needs_autoincrement'] = True @@ -83,19 +83,16 @@ def Column(*args, **kw): def add_seq(c, tbl): c._init_items( schema.Sequence(_truncate_name( - config.db.dialect, tbl.name + '_' + c.name + '_seq'), + config.db.dialect, tbl.name + '_' + c.name + '_seq'), optional=True) ) event.listen(col, 'after_parent_attach', add_seq, propagate=True) return col - - - def _truncate_name(dialect, name): if len(name) > dialect.max_identifier_length: return name[0:max(dialect.max_identifier_length - 6, 0)] + \ - "_" + hex(hash(name) % 64)[2:] + "_" + hex(hash(name) % 64)[2:] else: return name diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 2dca1443d..1d8010c8a 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -12,15 +12,17 @@ class TableDDLTest(fixtures.TestBase): def _simple_fixture(self): return Table('test_table', self.metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, + autoincrement=False), + Column('data', String(50)) + ) def _underscore_fixture(self): return Table('_test_table', self.metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('_data', String(50)) - ) + Column('id', Integer, primary_key=True, + autoincrement=False), + Column('_data', String(50)) + ) def _simple_roundtrip(self, table): with config.db.begin() as conn: diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 3444e15c8..92d3d93e5 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -21,15 +21,15 @@ class LastrowidTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', String(50)) + ) Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(50)) + ) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() @@ -59,8 +59,9 @@ class LastrowidTest(fixtures.TablesTest): ) # failed on pypy1.9 but seems to be OK on pypy 2.1 - #@exclusions.fails_if(lambda: util.pypy, "lastrowid not maintained after " - # "connection close") + # @exclusions.fails_if(lambda: util.pypy, + # "lastrowid not maintained after " + # "connection close") @requirements.dbapi_lastrowid def test_native_lastrowid_autoinc(self): r = config.db.execute( @@ -81,19 +82,19 @@ class InsertBehaviorTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, \ - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, + 
test_needs_autoincrement=True), + Column('data', String(50)) + ) Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(50)) + ) def test_autoclose_on_insert(self): if requirements.returning.enabled: engine = engines.testing_engine( - options={'implicit_returning': False}) + options={'implicit_returning': False}) else: engine = config.db @@ -119,12 +120,12 @@ class InsertBehaviorTest(fixtures.TablesTest): def test_empty_insert(self): r = config.db.execute( self.tables.autoinc_pk.insert(), - ) + ) assert r.closed r = config.db.execute( - self.tables.autoinc_pk.select().\ - where(self.tables.autoinc_pk.c.id != None) + self.tables.autoinc_pk.select(). + where(self.tables.autoinc_pk.c.id != None) ) assert len(r.fetchall()) @@ -133,21 +134,20 @@ class InsertBehaviorTest(fixtures.TablesTest): def test_insert_from_select(self): table = self.tables.manual_pk config.db.execute( - table.insert(), - [ - dict(id=1, data="data1"), - dict(id=2, data="data2"), - dict(id=3, data="data3"), - ] + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ] ) - config.db.execute( - table.insert(inline=True). - from_select( - ("id", "data",), select([table.c.id + 5, table.c.data]).where( - table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True). + from_select(("id", "data",), + select([table.c.id + 5, table.c.data]). + where(table.c.data.in_(["data2", "data3"])) + ), ) eq_( @@ -158,6 +158,7 @@ class InsertBehaviorTest(fixtures.TablesTest): ("data3", ), ("data3", )] ) + class ReturningTest(fixtures.TablesTest): run_create_tables = 'each' __requires__ = 'returning', 'autoincrement_insert' @@ -175,10 +176,10 @@ class ReturningTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, \ - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', String(50)) + ) @requirements.fetch_rows_post_commit def test_explicit_returning_pk_autocommit(self): @@ -186,7 +187,7 @@ class ReturningTest(fixtures.TablesTest): table = self.tables.autoinc_pk r = engine.execute( table.insert().returning( - table.c.id), + table.c.id), data="some data" ) pk = r.first()[0] @@ -199,7 +200,7 @@ class ReturningTest(fixtures.TablesTest): with engine.begin() as conn: r = conn.execute( table.insert().returning( - table.c.id), + table.c.id), data="some data" ) pk = r.first()[0] diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 762c9955c..7cc5fd160 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -24,9 +24,9 @@ class HasTableTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) def test_has_table(self): with config.db.begin() as conn: @@ -34,8 +34,6 @@ class HasTableTest(fixtures.TablesTest): assert not config.db.dialect.has_table(conn, "nonexistent_table") - - class ComponentReflectionTest(fixtures.TablesTest): run_inserts = run_deletes = None @@ -56,41 +54,42 @@ class ComponentReflectionTest(fixtures.TablesTest): 
if testing.requires.self_referential_foreign_keys.enabled: users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('%susers.user_id' % schema_prefix)), - schema=schema, - test_needs_fk=True, - ) + Column('user_id', sa.INT, primary_key=True), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('parent_user_id', sa.Integer, + sa.ForeignKey('%susers.user_id' % + schema_prefix)), + schema=schema, + test_needs_fk=True, + ) else: users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - schema=schema, - test_needs_fk=True, - ) + Column('user_id', sa.INT, primary_key=True), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) Table("dingalings", metadata, - Column('dingaling_id', sa.Integer, primary_key=True), - Column('address_id', sa.Integer, - sa.ForeignKey('%semail_addresses.address_id' % - schema_prefix)), - Column('data', sa.String(30)), - schema=schema, - test_needs_fk=True, - ) + Column('dingaling_id', sa.Integer, primary_key=True), + Column('address_id', sa.Integer, + sa.ForeignKey('%semail_addresses.address_id' % + schema_prefix)), + Column('data', sa.String(30)), + schema=schema, + test_needs_fk=True, + ) Table('email_addresses', metadata, - Column('address_id', sa.Integer), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), - sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), - schema=schema, - test_needs_fk=True, - ) + Column('address_id', sa.Integer), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), + schema=schema, + test_needs_fk=True, + ) if testing.requires.index_reflection.enabled: cls.define_index(metadata, users) @@ -110,7 +109,7 @@ class ComponentReflectionTest(fixtures.TablesTest): fullname = "%s.%s" % (schema, table_name) view_name = fullname + '_v' query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, fullname) + view_name, fullname) event.listen( metadata, @@ -146,7 +145,7 @@ class ComponentReflectionTest(fixtures.TablesTest): order_by=None): meta = self.metadata users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + self.tables.email_addresses, self.tables.dingalings insp = inspect(meta.bind) if table_type == 'view': table_names = insp.get_view_names(schema) @@ -195,13 +194,13 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_columns(self, schema=None, table_type='table'): meta = MetaData(testing.db) users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + self.tables.email_addresses, self.tables.dingalings table_names = ['users', 'email_addresses'] if table_type == 'view': table_names = ['users_v', 'email_addresses_v'] insp = inspect(meta.bind) for table_name, table in zip(table_names, (users, - addresses)): + addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) self.assert_(len(cols) > 0, len(cols)) @@ -218,23 +217,24 @@ class ComponentReflectionTest(fixtures.TablesTest): # Oracle 
returns Date for DateTime. if testing.against('oracle') and ctype_def \ - in (sql_types.Date, sql_types.DateTime): + in (sql_types.Date, sql_types.DateTime): ctype_def = sql_types.Date # assert that the desired type and return type share # a base within one of the generic types. self.assert_(len(set(ctype.__mro__). - intersection(ctype_def.__mro__).intersection([ - sql_types.Integer, - sql_types.Numeric, - sql_types.DateTime, - sql_types.Date, - sql_types.Time, - sql_types.String, - sql_types._Binary, - ])) > 0, '%s(%s), %s(%s)' % (col.name, - col.type, cols[i]['name'], ctype)) + intersection(ctype_def.__mro__). + intersection([ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ])) > 0, '%s(%s), %s(%s)' % + (col.name, col.type, cols[i]['name'], ctype)) if not col.primary_key: assert cols[i]['default'] is None @@ -246,11 +246,11 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _type_round_trip(self, *types): t = Table('t', self.metadata, - *[ - Column('t%d' % i, type_) - for i, type_ in enumerate(types) - ] - ) + *[ + Column('t%d' % i, type_) + for i, type_ in enumerate(types) + ] + ) t.create() return [ @@ -261,8 +261,8 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection def test_numeric_reflection(self): for typ in self._type_round_trip( - sql_types.Numeric(18, 5), - ): + sql_types.Numeric(18, 5), + ): assert isinstance(typ, sql_types.Numeric) eq_(typ.precision, 18) eq_(typ.scale, 5) @@ -277,8 +277,8 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def test_nullable_reflection(self): t = Table('t', self.metadata, - Column('a', Integer, nullable=True), - Column('b', Integer, nullable=False)) + Column('a', Integer, nullable=True), + Column('b', Integer, nullable=False)) t.create() eq_( dict( @@ -288,7 +288,6 @@ class ComponentReflectionTest(fixtures.TablesTest): {"a": True, "b": False} ) - @testing.requires.table_reflection @testing.requires.schemas def test_get_columns_with_schema(self): @@ -311,11 +310,11 @@ class ComponentReflectionTest(fixtures.TablesTest): users_cons = insp.get_pk_constraint(users.name, schema=schema) users_pkeys = users_cons['constrained_columns'] - eq_(users_pkeys, ['user_id']) + eq_(users_pkeys, ['user_id']) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) addr_pkeys = addr_cons['constrained_columns'] - eq_(addr_pkeys, ['address_id']) + eq_(addr_pkeys, ['address_id']) with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons['name'], 'email_ad_pk') @@ -347,7 +346,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_foreign_keys(self, schema=None): meta = self.metadata users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + self.tables.email_addresses, self.tables.dingalings insp = inspect(meta.bind) expected_schema = schema # users @@ -366,7 +365,7 @@ class ComponentReflectionTest(fixtures.TablesTest): if testing.requires.self_referential_foreign_keys.enabled: eq_(fkey1['constrained_columns'], ['parent_user_id']) - #addresses + # addresses addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] @@ -392,7 +391,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_indexes(self, schema=None): meta = self.metadata users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + 
self.tables.email_addresses, self.tables.dingalings # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. insp = inspect(meta.bind) @@ -421,7 +420,6 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_indexes_with_schema(self): self._test_get_indexes(schema='test_schema') - @testing.requires.unique_constraint_reflection def test_get_unique_constraints(self): self._test_get_unique_constraints() @@ -468,12 +466,11 @@ class ComponentReflectionTest(fixtures.TablesTest): for orig, refl in zip(uniques, reflected): eq_(orig, refl) - @testing.provide_metadata def _test_get_view_definition(self, schema=None): meta = self.metadata users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + self.tables.email_addresses, self.tables.dingalings view_name1 = 'users_v' view_name2 = 'email_addresses_v' insp = inspect(meta.bind) @@ -496,7 +493,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_table_oid(self, table_name, schema=None): meta = self.metadata users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + self.tables.email_addresses, self.tables.dingalings insp = inspect(meta.bind) oid = insp.get_table_oid(table_name, schema) self.assert_(isinstance(oid, int)) @@ -527,14 +524,13 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) for tname, cname in [ - ('users', 'user_id'), - ('email_addresses', 'address_id'), - ('dingalings', 'dingaling_id'), - ]: + ('users', 'user_id'), + ('email_addresses', 'address_id'), + ('dingalings', 'dingaling_id'), + ]: cols = insp.get_columns(tname) id_ = dict((c['name'], c) for c in cols)[cname] assert id_.get('autoincrement', True) - __all__ = ('ComponentReflectionTest', 'HasTableTest') diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 2fdab4d17..9ffaa6e04 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -15,13 +15,13 @@ class RowFetchTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) Table('has_dates', metadata, - Column('id', Integer, primary_key=True), - Column('today', DateTime) - ) + Column('id', Integer, primary_key=True), + Column('today', DateTime) + ) @classmethod def insert_data(cls): @@ -43,9 +43,9 @@ class RowFetchTest(fixtures.TablesTest): def test_via_string(self): row = config.db.execute( - self.tables.plain_pk.select().\ - order_by(self.tables.plain_pk.c.id) - ).first() + self.tables.plain_pk.select(). + order_by(self.tables.plain_pk.c.id) + ).first() eq_( row['id'], 1 @@ -56,9 +56,9 @@ class RowFetchTest(fixtures.TablesTest): def test_via_int(self): row = config.db.execute( - self.tables.plain_pk.select().\ - order_by(self.tables.plain_pk.c.id) - ).first() + self.tables.plain_pk.select(). + order_by(self.tables.plain_pk.c.id) + ).first() eq_( row[0], 1 @@ -69,9 +69,9 @@ class RowFetchTest(fixtures.TablesTest): def test_via_col_object(self): row = config.db.execute( - self.tables.plain_pk.select().\ - order_by(self.tables.plain_pk.c.id) - ).first() + self.tables.plain_pk.select(). 
+ order_by(self.tables.plain_pk.c.id) + ).first() eq_( row[self.tables.plain_pk.c.id], 1 @@ -83,15 +83,14 @@ class RowFetchTest(fixtures.TablesTest): @requirements.duplicate_names_in_cursor_description def test_row_with_dupe_names(self): result = config.db.execute( - select([self.tables.plain_pk.c.data, - self.tables.plain_pk.c.data.label('data')]).\ - order_by(self.tables.plain_pk.c.id) - ) + select([self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label('data')]). + order_by(self.tables.plain_pk.c.id) + ) row = result.first() eq_(result.keys(), ['data', 'data']) eq_(row, ('d1', 'd1')) - def test_row_w_scalar_select(self): """test that a scalar select as a column is returned as such and that type conversion works OK. @@ -124,12 +123,13 @@ class PercentSchemaNamesTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): cls.tables.percent_table = Table('percent%table', metadata, - Column("percent%", Integer), - Column("spaces % more spaces", Integer), - ) - cls.tables.lightweight_percent_table = sql.table('percent%table', - sql.column("percent%"), - sql.column("spaces % more spaces"), + Column("percent%", Integer), + Column( + "spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table( + 'percent%table', sql.column("percent%"), + sql.column("spaces % more spaces") ) def test_single_roundtrip(self): @@ -152,8 +152,8 @@ class PercentSchemaNamesTest(fixtures.TablesTest): config.db.execute( percent_table.insert(), [{'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9}] + {'percent%': 9, 'spaces % more spaces': 10}, + {'percent%': 11, 'spaces % more spaces': 9}] ) self._assert_table() @@ -162,10 +162,10 @@ class PercentSchemaNamesTest(fixtures.TablesTest): lightweight_percent_table = self.tables.lightweight_percent_table for table in ( - percent_table, - percent_table.alias(), - lightweight_percent_table, - lightweight_percent_table.alias()): + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias()): eq_( list( config.db.execute( @@ -184,18 +184,18 @@ class PercentSchemaNamesTest(fixtures.TablesTest): list( config.db.execute( table.select(). - where(table.c['spaces % more spaces'].in_([9, 10])). - order_by(table.c['percent%']), + where(table.c['spaces % more spaces'].in_([9, 10])). + order_by(table.c['percent%']), ) ), - [ - (9, 10), - (11, 9) - ] + [ + (9, 10), + (11, 9) + ] ) - row = config.db.execute(table.select().\ - order_by(table.c['percent%'])).first() + row = config.db.execute(table.select(). + order_by(table.c['percent%'])).first() eq_(row['percent%'], 5) eq_(row['spaces % more spaces'], 12) @@ -211,9 +211,9 @@ class PercentSchemaNamesTest(fixtures.TablesTest): eq_( list( config.db.execute( - percent_table.\ - select().\ - order_by(percent_table.c['percent%']) + percent_table. + select(). 
+ order_by(percent_table.c['percent%']) ) ), [(5, 15), (7, 15), (9, 15), (11, 15)] diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 3461b1e94..3f14ada05 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -21,12 +21,12 @@ class OrderByLabelTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('q', String(50)), - Column('p', String(50)) - ) + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer), + Column('q', String(50)), + Column('p', String(50)) + ) @classmethod def insert_data(cls): @@ -86,15 +86,16 @@ class OrderByLabelTest(fixtures.TablesTest): [(7, ), (5, ), (3, )] ) + class LimitOffsetTest(fixtures.TablesTest): __backend__ = True @classmethod def define_tables(cls, metadata): Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer)) @classmethod def insert_data(cls): @@ -157,8 +158,8 @@ class LimitOffsetTest(fixtures.TablesTest): def test_bound_limit_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).\ - limit(bindparam("l")).offset(bindparam("o")), + select([table]).order_by(table.c.id). + limit(bindparam("l")).offset(bindparam("o")), [(2, 2, 3), (3, 3, 4)], params={"l": 2, "o": 1} ) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index 6bc2822fc..bbb4ba65c 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -7,6 +7,7 @@ from ... 
import Integer, String, Sequence, schema from ..schema import Table, Column + class SequenceTest(fixtures.TablesTest): __requires__ = ('sequences',) __backend__ = True @@ -16,15 +17,15 @@ class SequenceTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('seq_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), - Column('data', String(50)) - ) + Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), + Column('data', String(50)) + ) Table('seq_opt_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq', optional=True), - primary_key=True), - Column('data', String(50)) - ) + Column('id', Integer, Sequence('tab_id_seq', optional=True), + primary_key=True), + Column('data', String(50)) + ) def test_insert_roundtrip(self): config.db.execute( @@ -62,7 +63,6 @@ class SequenceTest(fixtures.TablesTest): [1] ) - def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() eq_( @@ -80,7 +80,7 @@ class HasSequenceTest(fixtures.TestBase): testing.db.execute(schema.CreateSequence(s1)) try: eq_(testing.db.dialect.has_sequence(testing.db, - 'user_id_seq'), True) + 'user_id_seq'), True) finally: testing.db.execute(schema.DropSequence(s1)) @@ -89,8 +89,8 @@ class HasSequenceTest(fixtures.TestBase): s1 = Sequence('user_id_seq', schema="test_schema") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, - 'user_id_seq', schema="test_schema"), True) + eq_(testing.db.dialect.has_sequence( + testing.db, 'user_id_seq', schema="test_schema"), True) finally: testing.db.execute(schema.DropSequence(s1)) @@ -101,7 +101,7 @@ class HasSequenceTest(fixtures.TestBase): @testing.requires.schemas def test_has_sequence_schemas_neg(self): eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema="test_schema"), + schema="test_schema"), False) @testing.requires.schemas @@ -110,7 +110,7 @@ class HasSequenceTest(fixtures.TestBase): testing.db.execute(schema.CreateSequence(s1)) try: eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema="test_schema"), + schema="test_schema"), False) finally: testing.db.execute(schema.DropSequence(s1)) @@ -124,5 +124,3 @@ class HasSequenceTest(fixtures.TestBase): False) finally: testing.db.execute(schema.DropSequence(s1)) - - diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 3a5134c96..230aeb1e9 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -5,7 +5,7 @@ from ..assertions import eq_ from ..config import requirements from sqlalchemy import Integer, Unicode, UnicodeText, select from sqlalchemy import Date, DateTime, Time, MetaData, String, \ - Text, Numeric, Float, literal, Boolean + Text, Numeric, Float, literal, Boolean from ..schema import Table, Column from ... import testing import decimal @@ -28,9 +28,9 @@ class _LiteralRoundTripFixture(object): for value in input_: ins = t.insert().values(x=literal(value)).compile( - dialect=testing.db.dialect, - compile_kwargs=dict(literal_binds=True) - ) + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True) + ) testing.db.execute(ins) for row in t.select().execute(): @@ -43,17 +43,17 @@ class _LiteralRoundTripFixture(object): class _UnicodeFixture(_LiteralRoundTripFixture): __requires__ = 'unicode_data', - data = u("Alors vous imaginez ma surprise, au lever du jour, "\ - "quand une drôle de petite voix m’a réveillé. 
Elle "\ - "disait: « S’il vous plaît… dessine-moi un mouton! »") + data = u("Alors vous imaginez ma surprise, au lever du jour, " + "quand une drôle de petite voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi un mouton! »") @classmethod def define_tables(cls, metadata): Table('unicode_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('unicode_data', cls.datatype), - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('unicode_data', cls.datatype), + ) def test_round_trip(self): unicode_table = self.tables.unicode_table @@ -66,10 +66,10 @@ class _UnicodeFixture(_LiteralRoundTripFixture): ) row = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) - ).first() + select([ + unicode_table.c.unicode_data, + ]) + ).first() eq_( row, @@ -91,10 +91,10 @@ class _UnicodeFixture(_LiteralRoundTripFixture): ) rows = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) - ).fetchall() + select([ + unicode_table.c.unicode_data, + ]) + ).fetchall() eq_( rows, [(self.data, ) for i in range(3)] @@ -110,8 +110,8 @@ class _UnicodeFixture(_LiteralRoundTripFixture): {"unicode_data": u('')} ) row = config.db.execute( - select([unicode_table.c.unicode_data]) - ).first() + select([unicode_table.c.unicode_data]) + ).first() eq_(row, (u(''),)) def test_literal(self): @@ -139,6 +139,7 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): def test_empty_strings_text(self): self._test_empty_strings() + class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): __requires__ = 'text_type', __backend__ = True @@ -146,10 +147,10 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('text_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('text_data', Text), - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('text_data', Text), + ) def test_text_roundtrip(self): text_table = self.tables.text_table @@ -159,8 +160,8 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): {"text_data": 'some text'} ) row = config.db.execute( - select([text_table.c.text_data]) - ).first() + select([text_table.c.text_data]) + ).first() eq_(row, ('some text',)) def test_text_empty_strings(self): @@ -171,8 +172,8 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): {"text_data": ''} ) row = config.db.execute( - select([text_table.c.text_data]) - ).first() + select([text_table.c.text_data]) + ).first() eq_(row, ('',)) def test_literal(self): @@ -186,6 +187,7 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): data = r'backslash one \ backslash two \\ end' self._literal_round_trip(Text, [data], [data]) + class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @@ -194,7 +196,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): metadata = MetaData() foo = Table('foo', metadata, Column('one', String) - ) + ) foo.create(config.db) foo.drop(config.db) @@ -217,10 +219,10 @@ class _DateFixture(_LiteralRoundTripFixture): @classmethod def define_tables(cls, metadata): Table('date_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('date_data', cls.datatype), - ) + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('date_data', cls.datatype), + ) def test_round_trip(self): date_table = 
self.tables.date_table @@ -231,10 +233,10 @@ class _DateFixture(_LiteralRoundTripFixture): ) row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + select([ + date_table.c.date_data, + ]) + ).first() compare = self.compare or self.data eq_(row, @@ -250,10 +252,10 @@ class _DateFixture(_LiteralRoundTripFixture): ) row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + select([ + date_table.c.date_data, + ]) + ).first() eq_(row, (None,)) @testing.requires.datetime_literals @@ -262,7 +264,6 @@ class _DateFixture(_LiteralRoundTripFixture): self._literal_round_trip(self.datatype, [self.data], [compare]) - class DateTimeTest(_DateFixture, fixtures.TablesTest): __requires__ = 'datetime', __backend__ = True @@ -322,19 +323,22 @@ class DateHistoricTest(_DateFixture, fixtures.TablesTest): class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True + def test_literal(self): self._literal_round_trip(Integer, [5], [5]) + class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @testing.emits_warning(r".*does \*not\* support Decimal objects natively") @testing.provide_metadata - def _do_test(self, type_, input_, output, filter_=None, check_scale=False): + def _do_test(self, type_, input_, output, + filter_=None, check_scale=False): metadata = self.metadata t = Table('t', metadata, Column('x', type_)) t.create() - t.insert().execute([{'x':x} for x in input_]) + t.insert().execute([{'x': x} for x in input_]) result = set([row[0] for row in t.select().execute()]) output = set(output) @@ -348,7 +352,6 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): [str(x) for x in output], ) - @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_render_literal_numeric(self): self._literal_round_trip( @@ -369,17 +372,16 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], - [15.7563,], + [15.7563, ], filter_=lambda n: n is not None and round(n, 5) or None ) - @testing.requires.precision_generic_float_type def test_float_custom_scale(self): self._do_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], - [decimal.Decimal("15.7563827"),], + [decimal.Decimal("15.7563827"), ], check_scale=True ) @@ -421,7 +423,6 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): [decimal.Decimal("15.7563"), None], ) - def test_float_as_float(self): self._do_test( Float(precision=8), @@ -430,7 +431,6 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): filter_=lambda n: n is not None and round(n, 5) or None ) - @testing.requires.precision_numerics_general def test_precision_decimal(self): numbers = set([ @@ -445,7 +445,6 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): numbers, ) - @testing.requires.precision_numerics_enotation_large def test_enotation_decimal(self): """test exceedingly small decimals. @@ -475,7 +474,6 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): numbers ) - @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self): """test exceedingly large decimals. 
@@ -526,10 +524,10 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('boolean_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('value', Boolean), - Column('unconstrained_value', Boolean(create_constraint=False)), - ) + Column('id', Integer, primary_key=True, autoincrement=False), + Column('value', Boolean), + Column('unconstrained_value', Boolean(create_constraint=False)), + ) def test_render_literal_bool(self): self._literal_round_trip( @@ -551,11 +549,11 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) - ).first() + select([ + boolean_table.c.value, + boolean_table.c.unconstrained_value + ]) + ).first() eq_( row, @@ -576,11 +574,11 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) - ).first() + select([ + boolean_table.c.value, + boolean_table.c.unconstrained_value + ]) + ).first() eq_( row, @@ -588,11 +586,9 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) - - __all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', - 'DateTest', 'DateTimeTest', 'TextTest', - 'NumericTest', 'IntegerTest', - 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', - 'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest', - 'DateHistoricTest', 'StringTest', 'BooleanTest') + 'DateTest', 'DateTimeTest', 'TextTest', + 'NumericTest', 'IntegerTest', + 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', + 'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest', + 'DateHistoricTest', 'StringTest', 'BooleanTest') diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index 88dc95355..e4c61e74a 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -12,18 +12,18 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) @classmethod def insert_data(cls): config.db.execute( cls.tables.plain_pk.insert(), [ - {"id":1, "data":"d1"}, - {"id":2, "data":"d2"}, - {"id":3, "data":"d3"}, + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, ] ) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 35557582e..fc8390a79 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -61,8 +61,8 @@ def round_decimal(value, prec): # can also use shift() here but that is 2.6 only return (value * decimal.Decimal("1" + "0" * prec) - ).to_integral(decimal.ROUND_FLOOR) / \ - pow(10, prec) + ).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec) class RandomSet(set): @@ -138,7 +138,7 @@ def function_named(fn, name): fn.__name__ = name except TypeError: fn = types.FunctionType(fn.__code__, fn.__globals__, name, - fn.__defaults__, fn.__closure__) + fn.__defaults__, fn.__closure__) return fn @@ -196,6 +196,7 @@ def provide_metadata(fn, *args, **kw): class adict(dict): """Dict keys available as attributes. 
Shadows.""" + def __getattribute__(self, key): try: return self[key] diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 33d338e8f..b3314de6e 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -12,6 +12,7 @@ from .. import exc as sa_exc from .. import util import re + def testing_warn(msg, stacklevel=3): """Replaces sqlalchemy.util.warn during tests.""" |
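
Editor's note, for orientation only: the ComponentReflectionTest hunks above exercise SQLAlchemy's public Inspector interface. The following standalone sketch shows the same calls the suite drives — get_table_names(), get_columns() and get_pk_constraint(). It is an illustration added for readers of this diff, not part of the commit; the in-memory SQLite URL and the toy 'users' table are assumptions chosen for brevity.

# Hypothetical, self-contained illustration of the Inspector API that
# ComponentReflectionTest exercises; not part of this changeset.
import sqlalchemy as sa
from sqlalchemy import inspect

engine = sa.create_engine("sqlite://")          # assumption: throwaway in-memory SQLite
metadata = sa.MetaData()
sa.Table(
    "users", metadata,
    sa.Column("user_id", sa.Integer, primary_key=True),
    sa.Column("test1", sa.CHAR(5), nullable=False),
)
metadata.create_all(engine)

insp = inspect(engine)
print(insp.get_table_names())            # ['users']
print(insp.get_columns("users"))         # list of dicts: name, type, nullable, default, ...
print(insp.get_pk_constraint("users"))   # {'constrained_columns': ['user_id'], 'name': ...}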

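Likewise, the _LiteralRoundTripFixture hunks reformat calls to compile() with literal_binds, which renders bound parameters inline so the resulting SQL string can be executed as-is. A minimal sketch of that rendering path, again purely illustrative (it compiles against the default dialect and uses an assumed throwaway table named 't' rather than the fixture's own setup):

# Illustrative only: render bound parameters inline the way
# _LiteralRoundTripFixture does when building its INSERT statements.
import sqlalchemy as sa

metadata = sa.MetaData()
t = sa.Table("t", metadata, sa.Column("x", sa.Integer))

ins = t.insert().values(x=sa.literal(5)).compile(
    compile_kwargs={"literal_binds": True}
)
print(str(ins))   # e.g. INSERT INTO t (x) VALUES (5)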