diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/engine | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/engine')
| -rw-r--r-- | lib/sqlalchemy/engine/__init__.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 445 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 416 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 436 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 384 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 120 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/url.py | 115 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/util.py | 16 |
10 files changed, 1204 insertions, 859 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 6342b3c21..590359c38 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -57,10 +57,9 @@ from .interfaces import ( Dialect, ExecutionContext, ExceptionContext, - # backwards compat Compiled, - TypeCompiler + TypeCompiler, ) from .base import ( @@ -82,9 +81,7 @@ from .result import ( RowProxy, ) -from .util import ( - connection_memoize -) +from .util import connection_memoize from . import util, strategies @@ -92,7 +89,7 @@ from . import util, strategies # backwards compat from ..sql import ddl -default_strategy = 'plain' +default_strategy = "plain" def create_engine(*args, **kwargs): @@ -460,12 +457,12 @@ def create_engine(*args, **kwargs): """ - strategy = kwargs.pop('strategy', default_strategy) + strategy = kwargs.pop("strategy", default_strategy) strategy = strategies.strategies[strategy] return strategy.create(*args, **kwargs) -def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): +def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): """Create a new Engine instance using a configuration dictionary. The dictionary is typically produced from a config file. @@ -497,16 +494,15 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): """ - options = dict((key[len(prefix):], configuration[key]) - for key in configuration - if key.startswith(prefix)) - options['_coerce_config'] = True + options = dict( + (key[len(prefix) :], configuration[key]) + for key in configuration + if key.startswith(prefix) + ) + options["_coerce_config"] = True options.update(kwargs) - url = options.pop('url') + url = options.pop("url") return create_engine(url, **options) -__all__ = ( - 'create_engine', - 'engine_from_config', -) +__all__ = ("create_engine", "engine_from_config") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4a057ee59..75d03b744 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -61,10 +61,16 @@ class Connection(Connectable): """ - def __init__(self, engine, connection=None, close_with_result=False, - _branch_from=None, _execution_options=None, - _dispatch=None, - _has_events=None): + def __init__( + self, + engine, + connection=None, + close_with_result=False, + _branch_from=None, + _execution_options=None, + _dispatch=None, + _has_events=None, + ): """Construct a new Connection. The constructor here is not public and is only called only by an @@ -86,8 +92,11 @@ class Connection(Connectable): self._has_events = _branch_from._has_events self.schema_for_object = _branch_from.schema_for_object else: - self.__connection = connection \ - if connection is not None else engine.raw_connection() + self.__connection = ( + connection + if connection is not None + else engine.raw_connection() + ) self.__transaction = None self.__savepoint_seq = 0 self.should_close_with_result = close_with_result @@ -101,7 +110,8 @@ class Connection(Connectable): # want to handle any of the engine's events in that case. self.dispatch = self.dispatch._join(engine.dispatch) self._has_events = _has_events or ( - _has_events is None and engine._has_events) + _has_events is None and engine._has_events + ) assert not _execution_options self._execution_options = engine._execution_options @@ -134,7 +144,8 @@ class Connection(Connectable): _branch_from=self, _execution_options=self._execution_options, _has_events=self._has_events, - _dispatch=self.dispatch) + _dispatch=self.dispatch, + ) @property def _root(self): @@ -322,8 +333,10 @@ class Connection(Connectable): def closed(self): """Return True if this connection is closed.""" - return '_Connection__connection' not in self.__dict__ \ + return ( + "_Connection__connection" not in self.__dict__ and not self.__can_reconnect + ) @property def invalidated(self): @@ -425,7 +438,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " - "transaction is rolled back") + "transaction is rolled back" + ) self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection @@ -437,14 +451,15 @@ class Connection(Connectable): # dialect initializer, where the connection is not wrapped in # _ConnectionFairy - return getattr(self.__connection, 'is_valid', False) + return getattr(self.__connection, "is_valid", False) @property def _still_open_and_connection_is_valid(self): - return \ - not self.closed and \ - not self.invalidated and \ - getattr(self.__connection, 'is_valid', False) + return ( + not self.closed + and not self.invalidated + and getattr(self.__connection, "is_valid", False) + ) @property def info(self): @@ -656,7 +671,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " - "is already in progress.") + "is already in progress." + ) if xid is None: xid = self.engine.dialect.create_xid() self.__transaction = TwoPhaseTransaction(self, xid) @@ -705,8 +721,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None else: @@ -725,8 +743,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None @@ -738,7 +758,7 @@ class Connection(Connectable): if name is None: self.__savepoint_seq += 1 - name = 'sa_savepoint_%s' % self.__savepoint_seq + name = "sa_savepoint_%s" % self.__savepoint_seq if self._still_open_and_connection_is_valid: self.engine.dialect.do_savepoint(self, name) return name @@ -797,7 +817,8 @@ class Connection(Connectable): assert isinstance(self.__transaction, TwoPhaseTransaction) try: self.engine.dialect.do_rollback_twophase( - self, xid, is_prepared) + self, xid, is_prepared + ) finally: if self.connection._reset_agent is self.__transaction: self.connection._reset_agent = None @@ -950,16 +971,16 @@ class Connection(Connectable): def _execute_function(self, func, multiparams, params): """Execute a sql.FunctionElement object.""" - return self._execute_clauseelement(func.select(), - multiparams, params) + return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): """Execute a schema.ColumnDefault object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - default, multiparams, params = \ - fn(self, default, multiparams, params) + default, multiparams, params = fn( + self, default, multiparams, params + ) try: try: @@ -972,8 +993,7 @@ class Connection(Connectable): conn = self._revalidate_connection() dialect = self.dialect - ctx = dialect.execution_ctx_cls._init_default( - dialect, self, conn) + ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -982,8 +1002,9 @@ class Connection(Connectable): self.close() if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - default, multiparams, params, ret) + self.dispatch.after_execute( + self, default, multiparams, params, ret + ) return ret @@ -992,25 +1013,25 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - ddl, multiparams, params = \ - fn(self, ddl, multiparams, params) + ddl, multiparams, params = fn(self, ddl, multiparams, params) dialect = self.dialect compiled = ddl.compile( dialect=dialect, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, compiled, None, - compiled + compiled, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - ddl, multiparams, params, ret) + self.dispatch.after_execute(self, ddl, multiparams, params, ret) return ret def _execute_clauseelement(self, elem, multiparams, params): @@ -1018,8 +1039,7 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - elem, multiparams, params = \ - fn(self, elem, multiparams, params) + elem, multiparams, params = fn(self, elem, multiparams, params) distilled_params = _distill_params(multiparams, params) if distilled_params: @@ -1030,38 +1050,45 @@ class Connection(Connectable): keys = [] dialect = self.dialect - if 'compiled_cache' in self._execution_options: + if "compiled_cache" in self._execution_options: key = ( - dialect, elem, tuple(sorted(keys)), + dialect, + elem, + tuple(sorted(keys)), self.schema_for_object.hash_key, - len(distilled_params) > 1 + len(distilled_params) > 1, ) - compiled_sql = self._execution_options['compiled_cache'].get(key) + compiled_sql = self._execution_options["compiled_cache"].get(key) if compiled_sql is None: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None + if not self.schema_for_object.is_default + else None, ) - self._execution_options['compiled_cache'][key] = compiled_sql + self._execution_options["compiled_cache"][key] = compiled_sql else: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_params, - compiled_sql, distilled_params + compiled_sql, + distilled_params, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - elem, multiparams, params, ret) + self.dispatch.after_execute(self, elem, multiparams, params, ret) return ret def _execute_compiled(self, compiled, multiparams, params): @@ -1069,8 +1096,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - compiled, multiparams, params = \ - fn(self, compiled, multiparams, params) + compiled, multiparams, params = fn( + self, compiled, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1079,11 +1107,13 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled, parameters, - compiled, parameters + compiled, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - compiled, multiparams, params, ret) + self.dispatch.after_execute( + self, compiled, multiparams, params, ret + ) return ret def _execute_text(self, statement, multiparams, params): @@ -1091,8 +1121,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - statement, multiparams, params = \ - fn(self, statement, multiparams, params) + statement, multiparams, params = fn( + self, statement, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1101,16 +1132,18 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_statement, statement, parameters, - statement, parameters + statement, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - statement, multiparams, params, ret) + self.dispatch.after_execute( + self, statement, multiparams, params, ret + ) return ret - def _execute_context(self, dialect, constructor, - statement, parameters, - *args): + def _execute_context( + self, dialect, constructor, statement, parameters, *args + ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`.ResultProxy`.""" @@ -1127,31 +1160,36 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except BaseException as e: self._handle_dbapi_exception( - e, - util.text_type(statement), parameters, - None, None) + e, util.text_type(statement), parameters, None, None + ) if context.compiled: context.pre_exec() - cursor, statement, parameters = context.cursor, \ - context.statement, \ - context.parameters + cursor, statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) if not context.executemany: parameters = parameters[0] if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, context.executemany) + statement, parameters = fn( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info( - "%r", - sql_util._repr_params(parameters, batches=10) + "%r", sql_util._repr_params(parameters, batches=10) ) evt_handled = False @@ -1164,10 +1202,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_executemany( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) elif not parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: @@ -1176,9 +1212,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute_no_params( - cursor, - statement, - context) + cursor, statement, context + ) else: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: @@ -1187,24 +1222,22 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - context.executemany) + self.dispatch.after_cursor_execute( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if context.compiled: context.post_exec() @@ -1245,39 +1278,32 @@ class Connection(Connectable): """ if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, - False) + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info("%r", parameters) try: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): if fn(cursor, statement, parameters, context): break else: - self.dialect.do_execute( - cursor, - statement, - parameters, - context) + self.dialect.do_execute(cursor, statement, parameters, context) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - False) + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) def _safe_close_cursor(self, cursor): """Close the given cursor, catching exceptions @@ -1289,17 +1315,15 @@ class Connection(Connectable): except Exception: # log the error through the connection pool's logger. self.engine.pool.logger.error( - "Error closing cursor", exc_info=True) + "Error closing cursor", exc_info=True + ) _reentrant_error = False _is_disconnect = False - def _handle_dbapi_exception(self, - e, - statement, - parameters, - cursor, - context): + def _handle_dbapi_exception( + self, e, statement, parameters, cursor, context + ): exc_info = sys.exc_info() if context and context.exception is None: @@ -1309,15 +1333,14 @@ class Connection(Connectable): if not self._is_disconnect: self._is_disconnect = ( - isinstance(e, self.dialect.dbapi.Error) and - not self.closed and - self.dialect.is_disconnect( + isinstance(e, self.dialect.dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( e, self.__connection if not self.invalidated else None, - cursor) - ) or ( - is_exit_exception and not self.closed - ) + cursor, + ) + ) or (is_exit_exception and not self.closed) if context: context.is_disconnect = self._is_disconnect @@ -1326,20 +1349,24 @@ class Connection(Connectable): if self._reentrant_error: util.raise_from_cause( - exc.DBAPIError.instance(statement, - parameters, - e, - self.dialect.dbapi.Error, - dialect=self.dialect), - exc_info + exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + dialect=self.dialect, + ), + exc_info, ) self._reentrant_error = True try: # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it - should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ - (statement is not None - and context is None and not is_exit_exception) + should_wrap = isinstance(e, self.dialect.dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) if should_wrap: sqlalchemy_exception = exc.DBAPIError.instance( @@ -1348,30 +1375,37 @@ class Connection(Connectable): e, self.dialect.dbapi.Error, connection_invalidated=self._is_disconnect, - dialect=self.dialect) + dialect=self.dialect, + ) else: sqlalchemy_exception = None newraise = None - if (self._has_events or self.engine._has_events) and \ - not self._execution_options.get( - 'skip_user_error_events', False): + if ( + self._has_events or self.engine._has_events + ) and not self._execution_options.get( + "skip_user_error_events", False + ): # legacy dbapi_error event if should_wrap and context: - self.dispatch.dbapi_error(self, - cursor, - statement, - parameters, - context, - e) + self.dispatch.dbapi_error( + self, cursor, statement, parameters, context, e + ) # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self.engine, - self, cursor, statement, - parameters, context, self._is_disconnect, - invalidate_pool_on_disconnect) + e, + sqlalchemy_exception, + self.engine, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + ) for fn in self.dispatch.handle_error: try: @@ -1388,13 +1422,15 @@ class Connection(Connectable): if self._is_disconnect != ctx.is_disconnect: self._is_disconnect = ctx.is_disconnect if sqlalchemy_exception: - sqlalchemy_exception.connection_invalidated = \ + sqlalchemy_exception.connection_invalidated = ( ctx.is_disconnect + ) # set up potentially user-defined value for # invalidate pool. - invalidate_pool_on_disconnect = \ + invalidate_pool_on_disconnect = ( ctx.invalidate_pool_on_disconnect + ) if should_wrap and context: context.handle_dbapi_exception(e) @@ -1408,10 +1444,7 @@ class Connection(Connectable): if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1441,7 +1474,8 @@ class Connection(Connectable): None, e, dialect.dbapi.Error, - connection_invalidated=is_disconnect) + connection_invalidated=is_disconnect, + ) else: sqlalchemy_exception = None @@ -1449,8 +1483,17 @@ class Connection(Connectable): if engine._has_events: ctx = ExceptionContextImpl( - e, sqlalchemy_exception, engine, None, None, None, - None, None, is_disconnect, True) + e, + sqlalchemy_exception, + engine, + None, + None, + None, + None, + None, + is_disconnect, + True, + ) for fn in engine.dispatch.handle_error: try: # handler returns an exception; @@ -1463,18 +1506,15 @@ class Connection(Connectable): newraise = _raised break - if sqlalchemy_exception and \ - is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = \ - is_disconnect = ctx.is_disconnect + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1545,16 +1585,25 @@ class Connection(Connectable): return callable_(self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, **kwargs): - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" - def __init__(self, exception, sqlalchemy_exception, - engine, connection, cursor, statement, parameters, - context, is_disconnect, invalidate_pool_on_disconnect): + def __init__( + self, + exception, + sqlalchemy_exception, + engine, + connection, + cursor, + statement, + parameters, + context, + is_disconnect, + invalidate_pool_on_disconnect, + ): self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception @@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction): def _do_rollback(self): if self.is_active: self.connection._rollback_to_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) def _do_commit(self): if self.is_active: self.connection._release_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) class TwoPhaseTransaction(Transaction): @@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified): """ - def __init__(self, pool, dialect, url, - logging_name=None, echo=None, proxy=None, - execution_options=None - ): + def __init__( + self, + pool, + dialect, + url, + logging_name=None, + echo=None, + proxy=None, + execution_options=None, + ): self.pool = pool self.url = url self.dialect = dialect @@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified): :meth:`.Engine.execution_options` """ - self._execution_options = \ - self._execution_options.union(opt) + self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) self.dialect.set_engine_execution_options(self, opt) @@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified): echo = log.echo_property() def __repr__(self): - return 'Engine(%r)' % self.url + return "Engine(%r)" % self.url def dispose(self): """Dispose of the connection pool used by this :class:`.Engine`. @@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified): else: yield connection - def _run_visitor(self, visitorcallable, element, - connection=None, **kwargs): + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): with self._optional_conn_ctx_manager(connection) as conn: conn._run_visitor(visitorcallable, element, **kwargs) @@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified): self, self._wrap_pool_connect(self.pool.connect, None), close_with_result=close_with_result, - **kwargs) + **kwargs + ) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified): except dialect.dbapi.Error as e: if connection is None: Connection._handle_dbapi_exception_noconnection( - e, dialect, self) + e, dialect, self + ) else: util.reraise(*sys.exc_info()) @@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection) + self.pool.unique_connection, _connection + ) class OptionEngine(Engine): @@ -2225,10 +2285,11 @@ class OptionEngine(Engine): pool = property(_get_pool, _set_pool) def _get_has_events(self): - return self._proxied._has_events or \ - self.__dict__.get('_has_events', False) + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) def _set_has_events(self, value): - self.__dict__['_has_events'] = value + self.__dict__["_has_events"] = value _has_events = property(_get_has_events, _set_has_events) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 028abc4c2..d7c2518fe 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -24,13 +24,11 @@ import weakref from .. import event AUTOCOMMIT_REGEXP = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE +) # When we're handed literal SQL, ensure it's a SELECT query -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) +SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) class DefaultDialect(interfaces.Dialect): @@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect): supports_simple_order_by_label = True - engine_config_types = util.immutabledict([ - ('convert_unicode', util.bool_or_str('force')), - ('pool_timeout', util.asint), - ('echo', util.bool_or_str('debug')), - ('echo_pool', util.bool_or_str('debug')), - ('pool_recycle', util.asint), - ('pool_size', util.asint), - ('max_overflow', util.asint), - ('pool_threadlocal', util.asbool), - ]) + engine_config_types = util.immutabledict( + [ + ("convert_unicode", util.bool_or_str("force")), + ("pool_timeout", util.asint), + ("echo", util.bool_or_str("debug")), + ("echo_pool", util.bool_or_str("debug")), + ("pool_recycle", util.asint), + ("pool_size", util.asint), + ("max_overflow", util.asint), + ("pool_threadlocal", util.asbool), + ] + ) # if the NUMERIC type # returns decimal.Decimal. @@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect): supports_unicode_statements = False supports_unicode_binds = False returns_unicode_strings = False - description_encoding = 'use_encoding' + description_encoding = "use_encoding" - name = 'default' + name = "default" # length at which to truncate # any identifier. @@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True colspecs = {} - default_paramstyle = 'named' + default_paramstyle = "named" supports_default_values = False supports_empty_insert = True supports_multivalues_insert = False @@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect): """ - def __init__(self, convert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, - implicit_returning=None, - supports_right_nested_joins=None, - case_sensitive=True, - supports_native_boolean=None, - empty_in_strategy='static', - label_length=None, **kwargs): - - if not getattr(self, 'ported_sqla_06', True): + def __init__( + self, + convert_unicode=False, + encoding="utf-8", + paramstyle=None, + dbapi=None, + implicit_returning=None, + supports_right_nested_joins=None, + case_sensitive=True, + supports_native_boolean=None, + empty_in_strategy="static", + label_length=None, + **kwargs + ): + + if not getattr(self, "ported_sqla_06", True): util.warn( - "The %s dialect is not yet ported to the 0.6 format" % - self.name) + "The %s dialect is not yet ported to the 0.6 format" + % self.name + ) self.convert_unicode = convert_unicode self.encoding = encoding @@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect): self.paramstyle = self.default_paramstyle if implicit_returning is not None: self.implicit_returning = implicit_returning - self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) if supports_right_nested_joins is not None: @@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect): self.case_sensitive = case_sensitive self.empty_in_strategy = empty_in_strategy - if empty_in_strategy == 'static': + if empty_in_strategy == "static": self._use_static_in = True - elif empty_in_strategy in ('dynamic', 'dynamic_warn'): + elif empty_in_strategy in ("dynamic", "dynamic_warn"): self._use_static_in = False - self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn' + self._warn_on_empty_in = empty_in_strategy == "dynamic_warn" else: raise exc.ArgumentError( "empty_in_strategy may be 'static', " - "'dynamic', or 'dynamic_warn'") + "'dynamic', or 'dynamic_warn'" + ) if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError( "Label length of %d is greater than this dialect's" - " maximum identifier length of %d" % - (label_length, self.max_identifier_length)) + " maximum identifier length of %d" + % (label_length, self.max_identifier_length) + ) self.label_length = label_length - if self.description_encoding == 'use_encoding': - self._description_decoder = \ - processors.to_unicode_processor_factory( - encoding - ) + if self.description_encoding == "use_encoding": + self._description_decoder = processors.to_unicode_processor_factory( + encoding + ) elif self.description_encoding is not None: - self._description_decoder = \ - processors.to_unicode_processor_factory( - self.description_encoding - ) + self._description_decoder = processors.to_unicode_processor_factory( + self.description_encoding + ) self._encoder = codecs.getencoder(self.encoding) self._decoder = processors.to_unicode_processor_factory(self.encoding) @@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect): @classmethod def get_pool_class(cls, url): - return getattr(cls, 'poolclass', pool.QueuePool) + return getattr(cls, "poolclass", pool.QueuePool) def initialize(self, connection): try: - self.server_version_info = \ - self._get_server_version_info(connection) + self.server_version_info = self._get_server_version_info( + connection + ) except NotImplementedError: self.server_version_info = None try: - self.default_schema_name = \ - self._get_default_schema_name(connection) + self.default_schema_name = self._get_default_schema_name( + connection + ) except NotImplementedError: self.default_schema_name = None try: - self.default_isolation_level = \ - self.get_isolation_level(connection.connection) + self.default_isolation_level = self.get_isolation_level( + connection.connection + ) except NotImplementedError: self.default_isolation_level = None self.returns_unicode_strings = self._check_unicode_returns(connection) - if self.description_encoding is not None and \ - self._check_unicode_description(connection): + if ( + self.description_encoding is not None + and self._check_unicode_description(connection) + ): self._description_decoder = self.description_encoding = None self.do_rollback(connection.connection) @@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect): def check_unicode(test): statement = cast_to( - expression.select([test]).compile(dialect=self)) + expression.select([test]).compile(dialect=self) + ) try: cursor = connection.connection.cursor() connection._cursor_execute(cursor, statement, parameters) @@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect): except exc.DBAPIError as de: # note that _cursor_execute() will have closed the cursor # if an exception is thrown. - util.warn("Exception attempting to " - "detect unicode returns: %r" % de) + util.warn( + "Exception attempting to " + "detect unicode returns: %r" % de + ) return False else: return isinstance(row[0], util.text_type) @@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect): # detect plain VARCHAR expression.cast( expression.literal_column("'test plain returns'"), - sqltypes.VARCHAR(60) + sqltypes.VARCHAR(60), ), # detect if there's an NVARCHAR type with different behavior # available expression.cast( expression.literal_column("'test unicode returns'"), - sqltypes.Unicode(60) + sqltypes.Unicode(60), ), ] @@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect): try: cursor.execute( cast_to( - expression.select([ - expression.literal_column("'x'").label("some_label") - ]).compile(dialect=self) + expression.select( + [expression.literal_column("'x'").label("some_label")] + ).compile(dialect=self) ) ) return isinstance(cursor.description[0][0], util.text_type) @@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect): return sqltypes.adapt_type(typeobj, self.colspecs) def reflecttable( - self, connection, table, include_columns, exclude_columns, **opts): + self, connection, table, include_columns, exclude_columns, **opts + ): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable( - table, include_columns, exclude_columns, **opts) + table, include_columns, exclude_columns, **opts + ) def get_pk_constraint(self, conn, table_name, schema=None, **kw): """Compatibility method, adapts the result of get_primary_keys() @@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect): """ return { - 'constrained_columns': - self.get_primary_keys(conn, table_name, - schema=schema, **kw) + "constrained_columns": self.get_primary_keys( + conn, table_name, schema=schema, **kw + ) } def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( - "Identifier '%s' exceeds maximum length of %d characters" % - (ident, self.max_identifier_length) + "Identifier '%s' exceeds maximum length of %d characters" + % (ident, self.max_identifier_length) ) def connect(self, *cargs, **cparams): @@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect): return [[], opts] def set_engine_execution_options(self, engine, opts): - if 'isolation_level' in opts: - isolation_level = opts['isolation_level'] + if "isolation_level" in opts: + isolation_level = opts["isolation_level"] @event.listens_for(engine, "engine_connect") def set_isolation(connection, branch): if not branch: self._set_connection_isolation(connection, isolation_level) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) engine.schema_for_object = getter @event.listens_for(engine, "engine_connect") @@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect): connection.schema_for_object = getter def set_connection_execution_options(self, connection, opts): - if 'isolation_level' in opts: - self._set_connection_isolation(connection, opts['isolation_level']) + if "isolation_level" in opts: + self._set_connection_isolation(connection, opts["isolation_level"]) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) connection.schema_for_object = getter def _set_connection_isolation(self, connection, level): @@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect): "Connection is already established with a Transaction; " "setting isolation_level may implicitly rollback or commit " "the existing transaction, or have no effect until " - "next transaction") + "next transaction" + ) self.set_isolation_level(connection.connection, level) - connection.connection._connection_record.\ - finalize_callback.append(self.reset_isolation_level) + connection.connection._connection_record.finalize_callback.append( + self.reset_isolation_level + ) def do_begin(self, dbapi_connection): pass @@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self @classmethod - def _init_compiled(cls, dialect, connection, dbapi_connection, - compiled, parameters): + def _init_compiled( + cls, dialect, connection, dbapi_connection, compiled, parameters + ): """Initialize execution context for a Compiled construct.""" self = cls.__new__(cls) @@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext): assert compiled.can_execute self.execution_options = compiled.execution_options.union( - connection._execution_options) + connection._execution_options + ) self.result_column_struct = ( - compiled._result_columns, compiled._ordered_columns, - compiled._textual_ordered_columns) + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) self.unicode_statement = util.text_type(compiled) if not dialect.supports_unicode_statements: self.statement = self.unicode_statement.encode( - self.dialect.encoding) + self.dialect.encoding + ) else: self.statement = self.unicode_statement @@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not parameters: self.compiled_parameters = [compiled.construct_params()] else: - self.compiled_parameters = \ - [compiled.construct_params(m, _group_number=grp) for - grp, m in enumerate(parameters)] + self.compiled_parameters = [ + compiled.construct_params(m, _group_number=grp) + for grp, m in enumerate(parameters) + ] self.executemany = len(parameters) > 1 @@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( - compiled.returning and not compiled.statement._returning) + compiled.returning and not compiled.statement._returning + ) if self.compiled.insert_prefetch or self.compiled.update_prefetch: if self.executemany: @@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect._encoder(key)[0], processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key, processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ if self.executemany: raise exc.InvalidRequestError( - "'expanding' parameters can't be used with " - "executemany()") + "'expanding' parameters can't be used with " "executemany()" + ) if self.compiled.positional and self.compiled._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric' raise NotImplementedError( "'expanding' bind parameters not supported with " - "'numeric' paramstyle at this time.") + "'numeric' paramstyle at this time." + ) self._expanded_parameters = {} @@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): to_update_sets = {} for name in ( - self.compiled.positiontup if compiled.positional + self.compiled.positiontup + if compiled.positional else self.compiled.binds ): parameter = self.compiled.binds[name] @@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not values: to_update = to_update_sets[name] = [] - replacement_expressions[name] = ( - self.compiled.visit_empty_set_expr( - parameter._expanding_in_types - if parameter._expanding_in_types - else [parameter.type] - ) + replacement_expressions[ + name + ] = self.compiled.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] ) elif isinstance(values[0], (tuple, list)): @@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for j, value in enumerate(tuple_element, 1) ] replacement_expressions[name] = ", ".join( - "(%s)" % ", ".join( - self.compiled.bindtemplate % { - "name": - to_update[i * len(tuple_element) + j][0] + "(%s)" + % ", ".join( + self.compiled.bindtemplate + % { + "name": to_update[ + i * len(tuple_element) + j + ][0] } for j, value in enumerate(tuple_element) ) for i, tuple_element in enumerate(values) - ) else: to_update = to_update_sets[name] = [ @@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for i, value in enumerate(values, 1) ] replacement_expressions[name] = ", ".join( - self.compiled.bindtemplate % { - "name": key} + self.compiled.bindtemplate % {"name": key} for key, value in to_update ) compiled_params.update(to_update) processors.update( (key, processors[name]) - for key, value in to_update if name in processors + for key, value in to_update + if name in processors ) if compiled.positional: positiontup.extend(name for name, value in to_update) self._expanded_parameters[name] = [ - expand_key for expand_key, value in to_update] + expand_key for expand_key, value in to_update + ] elif compiled.positional: positiontup.append(name) @@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return replacement_expressions[m.group(1)] self.statement = re.sub( - r"\[EXPANDING_(\S+)\]", - process_expanding, - self.statement + r"\[EXPANDING_(\S+)\]", process_expanding, self.statement ) return positiontup @classmethod - def _init_statement(cls, dialect, connection, dbapi_connection, - statement, parameters): + def _init_statement( + cls, dialect, connection, dbapi_connection, statement, parameters + ): """Initialize execution context for a string SQL statement.""" self = cls.__new__(cls) @@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for d in parameters ] or [{}] else: - self.parameters = [dialect.execute_sequence_format(p) - for p in parameters] + self.parameters = [ + dialect.execute_sequence_format(p) for p in parameters + ] self.executemany = len(parameters) > 1 - if not dialect.supports_unicode_statements and \ - isinstance(statement, util.text_type): + if not dialect.supports_unicode_statements and isinstance( + statement, util.text_type + ): self.unicode_statement = statement self.statement = dialect._encoder(statement)[0] else: @@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @util.memoized_property def should_autocommit(self): - autocommit = self.execution_options.get('autocommit', - not self.compiled and - self.statement and - expression.PARSE_AUTOCOMMIT - or False) + autocommit = self.execution_options.get( + "autocommit", + not self.compiled + and self.statement + and expression.PARSE_AUTOCOMMIT + or False, + ) if autocommit is expression.PARSE_AUTOCOMMIT: return self.should_autocommit_text(self.unicode_statement) @@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ conn = self.root_connection - if isinstance(stmt, util.text_type) and \ - not self.dialect.supports_unicode_statements: + if ( + isinstance(stmt, util.text_type) + and not self.dialect.supports_unicode_statements + ): stmt = self.dialect._encoder(stmt)[0] if self.dialect.positional: @@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( - self.dialect, - self.cursor.description[0][1] + self.dialect, self.cursor.description[0][1] ) if proc: return proc(r) @@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return False if self.dialect.server_side_cursors: - use_server_side = \ - self.execution_options.get('stream_results', True) and ( - (self.compiled and isinstance(self.compiled.statement, - expression.Selectable) - or - ( - (not self.compiled or - isinstance(self.compiled.statement, - expression.TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match( - self.statement)) - ) + use_server_side = self.execution_options.get( + "stream_results", True + ) and ( + ( + self.compiled + and isinstance( + self.compiled.statement, expression.Selectable + ) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause + ) + ) + and self.statement + and SERVER_SIDE_CURSOR_RE.match(self.statement) + ) ) + ) else: - use_server_side = \ - self.execution_options.get('stream_results', False) + use_server_side = self.execution_options.get( + "stream_results", False + ) return use_server_side @@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.dialect.supports_sane_multi_rowcount def _setup_crud_result_proxy(self): - if self.isinsert and \ - not self.executemany: - if not self._is_implicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: + if self.isinsert and not self.executemany: + if ( + not self._is_implicit_returning + and not self.compiled.inline + and self.dialect.postfetch_lastrowid + ): self._setup_ins_pk_from_lastrowid() @@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if autoinc_col is not None: # apply type post processors to the lastrowid proc = autoinc_col.type._cached_result_processor( - self.dialect, None) + self.dialect, None + ) if proc is not None: lastrowid = proc(lastrowid) self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - compiled_params.get(key_getter(c), None) + lastrowid + if c is autoinc_col + else compiled_params.get(key_getter(c), None) for c in table.primary_key ] else: @@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): table = self.compiled.statement.table compiled_params = self.compiled_parameters[0] self.inserted_primary_key = [ - compiled_params.get(key_getter(c), None) - for c in table.primary_key + compiled_params.get(key_getter(c), None) for c in table.primary_key ] def _setup_ins_pk_from_implicit_returning(self, row): @@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ] def lastrow_has_defaults(self): - return (self.isinsert or self.isupdate) and \ - bool(self.compiled.postfetch) + return (self.isinsert or self.isupdate) and bool( + self.compiled.postfetch + ) def set_input_sizes( - self, translate=None, include_types=None, exclude_types=None): + self, translate=None, include_types=None, exclude_types=None + ): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. @@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ - if not hasattr(self.compiled, 'bind_names'): + if not hasattr(self.compiled, "bind_names"): return inputsizes = {} @@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect_impl_cls = type(dialect_impl) dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi) - if dbtype is not None and ( - not exclude_types or dbtype not in exclude_types and - dialect_impl_cls not in exclude_types - ) and ( - not include_types or dbtype in include_types or - dialect_impl_cls in include_types + if ( + dbtype is not None + and ( + not exclude_types + or dbtype not in exclude_types + and dialect_impl_cls not in exclude_types + ) + and ( + not include_types + or dbtype in include_types + or dialect_impl_cls in include_types + ) ): inputsizes[bindparam] = dbtype else: @@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dbtype is not None: if key in self._expanded_parameters: positional_inputsizes.extend( - [dbtype] * len(self._expanded_parameters[key])) + [dbtype] * len(self._expanded_parameters[key]) + ) else: positional_inputsizes.append(dbtype) try: self.cursor.setinputsizes(*positional_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) else: keyword_inputsizes = {} for bindparam, key in self.compiled.bind_names.items(): @@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key = self.dialect._encoder(key)[0] if key in self._expanded_parameters: keyword_inputsizes.update( - (expand_key, dbtype) for expand_key - in self._expanded_parameters[key] + (expand_key, dbtype) + for expand_key in self._expanded_parameters[key] ) else: keyword_inputsizes[key] = dbtype @@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor.setinputsizes(**keyword_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: @@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): except AttributeError: raise exc.InvalidRequestError( "get_current_parameters() can only be invoked in the " - "context of a Python side column default function") - if isolate_multiinsert_groups and \ - self.isinsert and \ - self.compiled.statement._has_multi_parameters: + "context of a Python side column default function" + ) + if ( + isolate_multiinsert_groups + and self.isinsert + and self.compiled.statement._has_multi_parameters + ): if column._is_multiparam_column: index = column.index + 1 d = {column.original.key: parameters[column.key]} @@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): index = 0 keys = self.compiled.statement.parameters[0].keys() d.update( - (key, parameters["%s_m%d" % (key, index)]) - for key in keys + (key, parameters["%s_m%d" % (key, index)]) for key in keys ) return d else: @@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - self.current_parameters = compiled_parameters = \ - self.compiled_parameters[0] + self.current_parameters = ( + compiled_parameters + ) = self.compiled_parameters[0] for c in self.compiled.insert_prefetch: - if c.default and \ - not c.default.is_sequence and c.default.is_scalar: + if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: val = self.get_insert_default(c) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 9c3b24e9a..e10e6e884 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -198,7 +198,8 @@ class Dialect(object): pass def reflecttable( - self, connection, table, include_columns, exclude_columns): + self, connection, table, include_columns, exclude_columns + ): """Load table description from the database. Given a :class:`.Connection` and a @@ -367,7 +368,8 @@ class Dialect(object): raise NotImplementedError() def get_unique_constraints( - self, connection, table_name, schema=None, **kw): + self, connection, table_name, schema=None, **kw + ): r"""Return information about unique constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -389,8 +391,7 @@ class Dialect(object): raise NotImplementedError() - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): r"""Return information about check constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -412,8 +413,7 @@ class Dialect(object): raise NotImplementedError() - def get_table_comment( - self, connection, table_name, schema=None, **kw): + def get_table_comment(self, connection, table_name, schema=None, **kw): r"""Return the "comment" for the table identified by `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -613,8 +613,9 @@ class Dialect(object): raise NotImplementedError() - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Rollback a two phase transaction on the given connection. :param connection: a :class:`.Connection`. @@ -627,8 +628,9 @@ class Dialect(object): raise NotImplementedError() - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Commit a two phase transaction on the given connection. @@ -664,8 +666,9 @@ class Dialect(object): raise NotImplementedError() - def do_execute_no_params(self, cursor, statement, parameters, - context=None): + def do_execute_no_params( + self, cursor, statement, parameters, context=None + ): """Provide an implementation of ``cursor.execute(statement)``. The parameter collection should not be sent. @@ -899,6 +902,7 @@ class CreateEnginePlugin(object): .. versionadded:: 1.1 """ + def __init__(self, url, kwargs): """Contruct a new :class:`.CreateEnginePlugin`. @@ -1129,20 +1133,24 @@ class Connectable(object): raise NotImplementedError() - @util.deprecated("0.7", - "Use the create() method on the given schema " - "object directly, i.e. :meth:`.Table.create`, " - ":meth:`.Index.create`, :meth:`.MetaData.create_all`") + @util.deprecated( + "0.7", + "Use the create() method on the given schema " + "object directly, i.e. :meth:`.Table.create`, " + ":meth:`.Index.create`, :meth:`.MetaData.create_all`", + ) def create(self, entity, **kwargs): """Emit CREATE statements for the given schema entity. """ raise NotImplementedError() - @util.deprecated("0.7", - "Use the drop() method on the given schema " - "object directly, i.e. :meth:`.Table.drop`, " - ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`") + @util.deprecated( + "0.7", + "Use the drop() method on the given schema " + "object directly, i.e. :meth:`.Table.drop`, " + ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`", + ) def drop(self, entity, **kwargs): """Emit DROP statements for the given schema entity. """ @@ -1160,8 +1168,7 @@ class Connectable(object): """ raise NotImplementedError() - def _run_visitor(self, visitorcallable, element, - **kwargs): + def _run_visitor(self, visitorcallable, element, **kwargs): raise NotImplementedError() def _execute_clauseelement(self, elem, multiparams=None, params=None): diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 841bb4dfb..9b5fa2459 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -37,17 +37,17 @@ from .base import Connectable @util.decorator def cache(fn, self, con, *args, **kw): - info_cache = kw.get('info_cache', None) + info_cache = kw.get("info_cache", None) if info_cache is None: return fn(self, con, *args, **kw) key = ( fn.__name__, tuple(a for a in args if isinstance(a, util.string_types)), - tuple((k, v) for k, v in kw.items() if - isinstance(v, - util.string_types + util.int_types + (float, ) - ) - ) + tuple( + (k, v) + for k, v in kw.items() + if isinstance(v, util.string_types + util.int_types + (float,)) + ), ) ret = info_cache.get(key) if ret is None: @@ -99,7 +99,7 @@ class Inspector(object): self.bind = bind # set the engine - if hasattr(bind, 'engine'): + if hasattr(bind, "engine"): self.engine = bind.engine else: self.engine = bind @@ -130,7 +130,7 @@ class Inspector(object): See the example at :class:`.Inspector`. """ - if hasattr(bind.dialect, 'inspector'): + if hasattr(bind.dialect, "inspector"): return bind.dialect.inspector(bind) return Inspector(bind) @@ -153,9 +153,10 @@ class Inspector(object): """Return all schema names. """ - if hasattr(self.dialect, 'get_schema_names'): - return self.dialect.get_schema_names(self.bind, - info_cache=self.info_cache) + if hasattr(self.dialect, "get_schema_names"): + return self.dialect.get_schema_names( + self.bind, info_cache=self.info_cache + ) return [] def get_table_names(self, schema=None, order_by=None): @@ -196,17 +197,18 @@ class Inspector(object): """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) - if order_by == 'foreign_key': + if order_by == "foreign_key": tuples = [] for tname in tnames: for fkey in self.get_foreign_keys(tname, schema): - if tname != fkey['referred_table']: - tuples.append((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.append((fkey["referred_table"], tname)) tnames = list(topological.sort(tuples, tnames)) return tnames @@ -234,9 +236,10 @@ class Inspector(object): with an already-given :class:`.MetaData`. """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) @@ -246,20 +249,17 @@ class Inspector(object): fknames_for_table = {} for tname in tnames: fkeys = self.get_foreign_keys(tname, schema) - fknames_for_table[tname] = set( - [fk['name'] for fk in fkeys] - ) + fknames_for_table[tname] = set([fk["name"] for fk in fkeys]) for fkey in fkeys: - if tname != fkey['referred_table']: - tuples.add((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.add((fkey["referred_table"], tname)) try: candidate_sort = list(topological.sort(tuples, tnames)) except exc.CircularDependencyError as err: for edge in err.edges: tuples.remove(edge) remaining_fkcs.update( - (edge[1], fkc) - for fkc in fknames_for_table[edge[1]] + (edge[1], fkc) for fkc in fknames_for_table[edge[1]] ) candidate_sort = list(topological.sort(tuples, tnames)) @@ -278,7 +278,8 @@ class Inspector(object): """ return self.dialect.get_temp_table_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_temp_view_names(self): """return a list of temporary view names for the current bind. @@ -290,7 +291,8 @@ class Inspector(object): """ return self.dialect.get_temp_view_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_table_options(self, table_name, schema=None, **kw): """Return a dictionary of options specified when the table of the @@ -306,10 +308,10 @@ class Inspector(object): use :class:`.quoted_name`. """ - if hasattr(self.dialect, 'get_table_options'): + if hasattr(self.dialect, "get_table_options"): return self.dialect.get_table_options( - self.bind, table_name, schema, - info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) return {} def get_view_names(self, schema=None): @@ -320,8 +322,9 @@ class Inspector(object): """ - return self.dialect.get_view_names(self.bind, schema, - info_cache=self.info_cache) + return self.dialect.get_view_names( + self.bind, schema, info_cache=self.info_cache + ) def get_view_definition(self, view_name, schema=None): """Return definition for `view_name`. @@ -332,7 +335,8 @@ class Inspector(object): """ return self.dialect.get_view_definition( - self.bind, view_name, schema, info_cache=self.info_cache) + self.bind, view_name, schema, info_cache=self.info_cache + ) def get_columns(self, table_name, schema=None, **kw): """Return information about columns in `table_name`. @@ -364,18 +368,21 @@ class Inspector(object): """ - col_defs = self.dialect.get_columns(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + col_defs = self.dialect.get_columns( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) for col_def in col_defs: # make this easy and only return instances for coltype - coltype = col_def['type'] + coltype = col_def["type"] if not isinstance(coltype, TypeEngine): - col_def['type'] = coltype() + col_def["type"] = coltype() return col_defs - @deprecated('0.7', 'Call to deprecated method get_primary_keys.' - ' Use get_pk_constraint instead.') + @deprecated( + "0.7", + "Call to deprecated method get_primary_keys." + " Use get_pk_constraint instead.", + ) def get_primary_keys(self, table_name, schema=None, **kw): """Return information about primary keys in `table_name`. @@ -383,9 +390,9 @@ class Inspector(object): primary key information as a list of column names. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw)['constrained_columns'] + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + )["constrained_columns"] def get_pk_constraint(self, table_name, schema=None, **kw): """Return information about primary key constraint on `table_name`. @@ -407,9 +414,9 @@ class Inspector(object): use :class:`.quoted_name`. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -442,9 +449,9 @@ class Inspector(object): """ - return self.dialect.get_foreign_keys(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_foreign_keys( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_indexes(self, table_name, schema=None, **kw): """Return information about indexes in `table_name`. @@ -476,9 +483,9 @@ class Inspector(object): """ - return self.dialect.get_indexes(self.bind, table_name, - schema, - info_cache=self.info_cache, **kw) + return self.dialect.get_indexes( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_unique_constraints(self, table_name, schema=None, **kw): """Return information about unique constraints in `table_name`. @@ -504,7 +511,8 @@ class Inspector(object): """ return self.dialect.get_unique_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_table_comment(self, table_name, schema=None, **kw): """Return information about the table comment for ``table_name``. @@ -523,8 +531,8 @@ class Inspector(object): """ return self.dialect.get_table_comment( - self.bind, table_name, schema, info_cache=self.info_cache, - **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_check_constraints(self, table_name, schema=None, **kw): """Return information about check constraints in `table_name`. @@ -550,10 +558,12 @@ class Inspector(object): """ return self.dialect.get_check_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) - def reflecttable(self, table, include_columns, exclude_columns=(), - _extend_on=None): + def reflecttable( + self, table, include_columns, exclude_columns=(), _extend_on=None + ): """Given a Table object, load its internal constructs based on introspection. @@ -599,7 +609,8 @@ class Inspector(object): # reflect table options, like mysql_engine tbl_opts = self.get_table_options( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if tbl_opts: # add additional kwargs to the Table if the dialect # returned them @@ -615,185 +626,251 @@ class Inspector(object): cols_by_orig_name = {} for col_d in self.get_columns( - table_name, schema, **table.dialect_kwargs): + table_name, schema, **table.dialect_kwargs + ): found_table = True self._reflect_column( - table, col_d, include_columns, - exclude_columns, cols_by_orig_name) + table, + col_d, + include_columns, + exclude_columns, + cols_by_orig_name, + ) if not found_table: raise exc.NoSuchTableError(table.name) self._reflect_pk( - table_name, schema, table, cols_by_orig_name, exclude_columns) + table_name, schema, table, cols_by_orig_name, exclude_columns + ) self._reflect_fk( - table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ) self._reflect_indexes( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_unique_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_check_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_table_comment( table_name, schema, table, reflection_options ) def _reflect_column( - self, table, col_d, include_columns, - exclude_columns, cols_by_orig_name): + self, table, col_d, include_columns, exclude_columns, cols_by_orig_name + ): - orig_name = col_d['name'] + orig_name = col_d["name"] table.dispatch.column_reflect(self, table, col_d) # fetch name again as column_reflect is allowed to # change it - name = col_d['name'] - if (include_columns and name not in include_columns) \ - or (exclude_columns and name in exclude_columns): + name = col_d["name"] + if (include_columns and name not in include_columns) or ( + exclude_columns and name in exclude_columns + ): return - coltype = col_d['type'] + coltype = col_d["type"] col_kw = dict( (k, col_d[k]) for k in [ - 'nullable', 'autoincrement', 'quote', 'info', 'key', - 'comment'] + "nullable", + "autoincrement", + "quote", + "info", + "key", + "comment", + ] if k in col_d ) - if 'dialect_options' in col_d: - col_kw.update(col_d['dialect_options']) + if "dialect_options" in col_d: + col_kw.update(col_d["dialect_options"]) colargs = [] - if col_d.get('default') is not None: - default = col_d['default'] + if col_d.get("default") is not None: + default = col_d["default"] if isinstance(default, sql.elements.TextClause): default = sa_schema.DefaultClause(default, _reflected=True) elif not isinstance(default, sa_schema.FetchedValue): default = sa_schema.DefaultClause( - sql.text(col_d['default']), _reflected=True) + sql.text(col_d["default"]), _reflected=True + ) colargs.append(default) - if 'sequence' in col_d: + if "sequence" in col_d: self._reflect_col_sequence(col_d, colargs) - cols_by_orig_name[orig_name] = col = \ - sa_schema.Column(name, coltype, *colargs, **col_kw) + cols_by_orig_name[orig_name] = col = sa_schema.Column( + name, coltype, *colargs, **col_kw + ) if col.key in table.primary_key: col.primary_key = True table.append_column(col) def _reflect_col_sequence(self, col_d, colargs): - if 'sequence' in col_d: + if "sequence" in col_d: # TODO: mssql and sybase are using this. - seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] + seq = col_d["sequence"] + sequence = sa_schema.Sequence(seq["name"], 1, 1) + if "start" in seq: + sequence.start = seq["start"] + if "increment" in seq: + sequence.increment = seq["increment"] colargs.append(sequence) def _reflect_pk( - self, table_name, schema, table, - cols_by_orig_name, exclude_columns): + self, table_name, schema, table, cols_by_orig_name, exclude_columns + ): pk_cons = self.get_pk_constraint( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if pk_cons: pk_cols = [ cols_by_orig_name[pk] - for pk in pk_cons['constrained_columns'] + for pk in pk_cons["constrained_columns"] if pk in cols_by_orig_name and pk not in exclude_columns ] # update pk constraint name - table.primary_key.name = pk_cons.get('name') + table.primary_key.name = pk_cons.get("name") # tell the PKConstraint to re-initialize # its column collection table.primary_key._reload(pk_cols) def _reflect_fk( - self, table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ): fkeys = self.get_foreign_keys( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) for fkey_d in fkeys: - conname = fkey_d['name'] + conname = fkey_d["name"] # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback constrained_columns = [ - cols_by_orig_name[c].key - if c in cols_by_orig_name else c - for c in fkey_d['constrained_columns'] + cols_by_orig_name[c].key if c in cols_by_orig_name else c + for c in fkey_d["constrained_columns"] ] if exclude_columns and set(constrained_columns).intersection( - exclude_columns): + exclude_columns + ): continue - referred_schema = fkey_d['referred_schema'] - referred_table = fkey_d['referred_table'] - referred_columns = fkey_d['referred_columns'] + referred_schema = fkey_d["referred_schema"] + referred_table = fkey_d["referred_table"] + referred_columns = fkey_d["referred_columns"] refspec = [] if referred_schema is not None: - sa_schema.Table(referred_table, table.metadata, - autoload=True, schema=referred_schema, - autoload_with=self.bind, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + schema=referred_schema, + autoload_with=self.bind, + _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: - refspec.append(".".join( - [referred_schema, referred_table, column])) + refspec.append( + ".".join([referred_schema, referred_table, column]) + ) else: - sa_schema.Table(referred_table, table.metadata, autoload=True, - autoload_with=self.bind, - schema=sa_schema.BLANK_SCHEMA, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: refspec.append(".".join([referred_table, column])) - if 'options' in fkey_d: - options = fkey_d['options'] + if "options" in fkey_d: + options = fkey_d["options"] else: options = {} table.append_constraint( - sa_schema.ForeignKeyConstraint(constrained_columns, refspec, - conname, link_to_name=True, - **options)) + sa_schema.ForeignKeyConstraint( + constrained_columns, + refspec, + conname, + link_to_name=True, + **options + ) + ) def _reflect_indexes( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: - name = index_d['name'] - columns = index_d['column_names'] - unique = index_d['unique'] - flavor = index_d.get('type', 'index') - dialect_options = index_d.get('dialect_options', {}) - - duplicates = index_d.get('duplicates_constraint') - if include_columns and \ - not set(columns).issubset(include_columns): + name = index_d["name"] + columns = index_d["column_names"] + unique = index_d["unique"] + flavor = index_d.get("type", "index") + dialect_options = index_d.get("dialect_options", {}) + + duplicates = index_d.get("duplicates_constraint") + if include_columns and not set(columns).issubset(include_columns): util.warn( - "Omitting %s key for (%s), key covers omitted columns." % - (flavor, ', '.join(columns))) + "Omitting %s key for (%s), key covers omitted columns." + % (flavor, ", ".join(columns)) + ) continue if duplicates: continue @@ -802,26 +879,36 @@ class Inspector(object): idx_cols = [] for c in columns: try: - idx_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + idx_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "%s key '%s' was not located in " - "columns for table '%s'" % ( - flavor, c, table_name - )) + "columns for table '%s'" % (flavor, c, table_name) + ) else: idx_cols.append(idx_col) sa_schema.Index( - name, *idx_cols, + name, + *idx_cols, _table=table, - **dict(list(dialect_options.items()) + [('unique', unique)]) + **dict(list(dialect_options.items()) + [("unique", unique)]) ) def _reflect_unique_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Unique Constraints try: @@ -831,15 +918,14 @@ class Inspector(object): return for const_d in constraints: - conname = const_d['name'] - columns = const_d['column_names'] - duplicates = const_d.get('duplicates_index') - if include_columns and \ - not set(columns).issubset(include_columns): + conname = const_d["name"] + columns = const_d["column_names"] + duplicates = const_d.get("duplicates_index") + if include_columns and not set(columns).issubset(include_columns): util.warn( "Omitting unique constraint key for (%s), " - "key covers omitted columns." % - ', '.join(columns)) + "key covers omitted columns." % ", ".join(columns) + ) continue if duplicates: continue @@ -848,20 +934,32 @@ class Inspector(object): constrained_cols = [] for c in columns: try: - constrained_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + constrained_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "unique constraint key '%s' was not located in " - "columns for table '%s'" % (c, table_name)) + "columns for table '%s'" % (c, table_name) + ) else: constrained_cols.append(constrained_col) table.append_constraint( - sa_schema.UniqueConstraint(*constrained_cols, name=conname)) + sa_schema.UniqueConstraint(*constrained_cols, name=conname) + ) def _reflect_check_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): try: constraints = self.get_check_constraints(table_name, schema) except NotImplementedError: @@ -869,14 +967,14 @@ class Inspector(object): return for const_d in constraints: - table.append_constraint( - sa_schema.CheckConstraint(**const_d)) + table.append_constraint(sa_schema.CheckConstraint(**const_d)) def _reflect_table_comment( - self, table_name, schema, table, reflection_options): + self, table_name, schema, table, reflection_options + ): try: comment_dict = self.get_table_comment(table_name, schema) except NotImplementedError: return else: - table.comment = comment_dict.get('text', None) + table.comment = comment_dict.get("text", None) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d4c862375..5ad0d2909 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -27,20 +27,25 @@ try: # the extension is present. def rowproxy_reconstructor(cls, state): return safe_rowproxy_reconstructor(cls, state) + + except ImportError: + def rowproxy_reconstructor(cls, state): obj = cls.__new__(cls) obj.__setstate__(state) return obj + try: from sqlalchemy.cresultproxy import BaseRowProxy + _baserowproxy_usecext = True except ImportError: _baserowproxy_usecext = False class BaseRowProxy(object): - __slots__ = ('_parent', '_row', '_processors', '_keymap') + __slots__ = ("_parent", "_row", "_processors", "_keymap") def __init__(self, parent, row, processors, keymap): """RowProxy objects are constructed by ResultProxy objects.""" @@ -51,8 +56,10 @@ except ImportError: self._keymap = keymap def __reduce__(self): - return (rowproxy_reconstructor, - (self.__class__, self.__getstate__())) + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) def values(self): """Return the values represented by this RowProxy as a list.""" @@ -76,8 +83,9 @@ except ImportError: except TypeError: if isinstance(key, slice): l = [] - for processor, value in zip(self._processors[key], - self._row[key]): + for processor, value in zip( + self._processors[key], self._row[key] + ): if processor is None: l.append(value) else: @@ -88,7 +96,8 @@ except ImportError: if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) if processor is not None: return processor(self._row[index]) else: @@ -110,29 +119,29 @@ class RowProxy(BaseRowProxy): mapped to the original Columns that produced this result set (for results that correspond to constructed SQL expressions). """ + __slots__ = () def __contains__(self, key): return self._parent._has_key(key) def __getstate__(self): - return { - '_parent': self._parent, - '_row': tuple(self) - } + return {"_parent": self._parent, "_row": tuple(self)} def __setstate__(self, state): - self._parent = parent = state['_parent'] - self._row = state['_row'] + self._parent = parent = state["_parent"] + self._row = state["_row"] self._processors = parent._processors self._keymap = parent._keymap __hash__ = None def _op(self, other, op): - return op(tuple(self), tuple(other)) \ - if isinstance(other, RowProxy) \ + return ( + op(tuple(self), tuple(other)) + if isinstance(other, RowProxy) else op(tuple(self), other) + ) def __lt__(self, other): return self._op(other, operator.lt) @@ -176,6 +185,7 @@ class RowProxy(BaseRowProxy): def itervalues(self): return iter(self) + try: # Register RowProxy with Sequence, # so sequence protocol is implemented @@ -189,8 +199,13 @@ class ResultMetaData(object): context.""" __slots__ = ( - '_keymap', 'case_sensitive', 'matched_on_name', - '_processors', 'keys', '_orig_processors') + "_keymap", + "case_sensitive", + "matched_on_name", + "_processors", + "keys", + "_orig_processors", + ) def __init__(self, parent, cursor_description): context = parent.context @@ -200,18 +215,25 @@ class ResultMetaData(object): self._orig_processors = None if context.result_column_struct: - result_columns, cols_are_ordered, textual_ordered = \ + result_columns, cols_are_ordered, textual_ordered = ( context.result_column_struct + ) num_ctx_cols = len(result_columns) else: - result_columns = cols_are_ordered = \ - num_ctx_cols = textual_ordered = False + result_columns = ( + cols_are_ordered + ) = num_ctx_cols = textual_ordered = False # merge cursor.description with the column info # present in the compiled structure, if any raw = self._merge_cursor_description( - context, cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered) + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ) self._keymap = {} if not _baserowproxy_usecext: @@ -223,23 +245,20 @@ class ResultMetaData(object): len_raw = len(raw) - self._keymap.update([ - (elem[0], (elem[3], elem[4], elem[0])) - for elem in raw - ] + [ - (elem[0] - len_raw, (elem[3], elem[4], elem[0])) - for elem in raw - ]) + self._keymap.update( + [(elem[0], (elem[3], elem[4], elem[0])) for elem in raw] + + [ + (elem[0] - len_raw, (elem[3], elem[4], elem[0])) + for elem in raw + ] + ) # processors in key order for certain per-row # views like __iter__ and slices self._processors = [elem[3] for elem in raw] # keymap by primary string... - by_key = dict([ - (elem[2], (elem[3], elem[4], elem[0])) - for elem in raw - ]) + by_key = dict([(elem[2], (elem[3], elem[4], elem[0])) for elem in raw]) # for compiled SQL constructs, copy additional lookup keys into # the key lookup map, such as Column objects, labels, @@ -264,29 +283,38 @@ class ResultMetaData(object): # copy secondary elements from compiled columns # into self._keymap, write in the potentially "ambiguous" # element - self._keymap.update([ - (obj_elem, by_key[elem[2]]) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, by_key[elem[2]]) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # if we did a pure positional match, then reset the # original "expression element" back to the "unambiguous" # entry. This is a new behavior in 1.1 which impacts # TextAsFrom but also straight compiled SQL constructs. if not self.matched_on_name: - self._keymap.update([ - (elem[4][0], (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - ]) + self._keymap.update( + [ + (elem[4][0], (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + ] + ) else: # no dupes - copy secondary elements from compiled # columns into self._keymap - self._keymap.update([ - (obj_elem, (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # update keymap with primary string names taking # precedence @@ -294,14 +322,19 @@ class ResultMetaData(object): # update keymap with "translated" names (sqlite-only thing) if not num_ctx_cols and context._translate_colname: - self._keymap.update([ - (elem[5], self._keymap[elem[2]]) - for elem in raw if elem[5] - ]) + self._keymap.update( + [(elem[5], self._keymap[elem[2]]) for elem in raw if elem[5]] + ) def _merge_cursor_description( - self, context, cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered): + self, + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ): """Merge a cursor.description with compiled result column information. There are at least four separate strategies used here, selected @@ -357,10 +390,12 @@ class ResultMetaData(object): case_sensitive = context.dialect.case_sensitive - if num_ctx_cols and \ - cols_are_ordered and \ - not textual_ordered and \ - num_ctx_cols == len(cursor_description): + if ( + num_ctx_cols + and cols_are_ordered + and not textual_ordered + and num_ctx_cols == len(cursor_description) + ): self.keys = [elem[0] for elem in result_columns] # pure positional 1-1 case; doesn't need to read # the names from cursor.description @@ -373,9 +408,9 @@ class ResultMetaData(object): type_, key, cursor_description[idx][1] ), obj, - None - ) for idx, (key, name, obj, type_) - in enumerate(result_columns) + None, + ) + for idx, (key, name, obj, type_) in enumerate(result_columns) ] else: # name-based or text-positional cases, where we need @@ -383,26 +418,32 @@ class ResultMetaData(object): if textual_ordered: # textual positional case raw_iterator = self._merge_textual_cols_by_position( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) elif num_ctx_cols: # compiled SQL with a mismatch of description cols # vs. compiled cols, or textual w/ unordered columns raw_iterator = self._merge_cols_by_name( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) else: # no compiled SQL, just a raw string raw_iterator = self._merge_cols_by_none( - context, cursor_description) + context, cursor_description + ) return [ ( - idx, colname, colname, + idx, + colname, + colname, context.get_result_processor( - mapped_type, colname, coltype), - obj, untranslated) - - for idx, colname, mapped_type, coltype, obj, untranslated - in raw_iterator + mapped_type, colname, coltype + ), + obj, + untranslated, + ) + for idx, colname, mapped_type, coltype, obj, untranslated in raw_iterator ] def _colnames_from_description(self, context, cursor_description): @@ -416,10 +457,14 @@ class ResultMetaData(object): dialect = context.dialect case_sensitive = dialect.case_sensitive translate_colname = context._translate_colname - description_decoder = dialect._description_decoder \ - if dialect.description_encoding else None - normalize_name = dialect.normalize_name \ - if dialect.requires_name_normalize else None + description_decoder = ( + dialect._description_decoder + if dialect.description_encoding + else None + ) + normalize_name = ( + dialect.normalize_name if dialect.requires_name_normalize else None + ) untranslated = None self.keys = [] @@ -444,20 +489,25 @@ class ResultMetaData(object): yield idx, colname, untranslated, coltype def _merge_textual_cols_by_position( - self, context, cursor_description, result_columns): + self, context, cursor_description, result_columns + ): dialect = context.dialect num_ctx_cols = len(result_columns) if result_columns else None if num_ctx_cols > len(cursor_description): util.warn( "Number of columns in textual SQL (%d) is " - "smaller than number of columns requested (%d)" % ( - num_ctx_cols, len(cursor_description) - )) + "smaller than number of columns requested (%d)" + % (num_ctx_cols, len(cursor_description)) + ) seen = set() - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): if idx < num_ctx_cols: ctx_rec = result_columns[idx] obj = ctx_rec[2] @@ -465,7 +515,8 @@ class ResultMetaData(object): if obj[0] in seen: raise exc.InvalidRequestError( "Duplicate column expression requested " - "in textual SQL: %r" % obj[0]) + "in textual SQL: %r" % obj[0] + ) seen.add(obj[0]) else: mapped_type = sqltypes.NULLTYPE @@ -479,8 +530,12 @@ class ResultMetaData(object): result_map = self._create_result_map(result_columns, case_sensitive) self.matched_on_name = True - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): try: ctx_rec = result_map[colname] except KeyError: @@ -493,8 +548,12 @@ class ResultMetaData(object): def _merge_cols_by_none(self, context, cursor_description): dialect = context.dialect - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated @classmethod @@ -525,27 +584,28 @@ class ResultMetaData(object): # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if key._label and ( - key._label - if self.case_sensitive - else key._label.lower()) in map: - result = map[key._label - if self.case_sensitive - else key._label.lower()] - elif hasattr(key, 'name') and ( - key.name - if self.case_sensitive - else key.name.lower()) in map: + if ( + key._label + and (key._label if self.case_sensitive else key._label.lower()) + in map + ): + result = map[ + key._label if self.case_sensitive else key._label.lower() + ] + elif ( + hasattr(key, "name") + and (key.name if self.case_sensitive else key.name.lower()) + in map + ): # match is only on name. - result = map[key.name - if self.case_sensitive - else key.name.lower()] + result = map[ + key.name if self.case_sensitive else key.name.lower() + ] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row # was unpickled. - if result is not None and \ - result[1] is not None: + if result is not None and result[1] is not None: for obj in result[1]: if key._compare_name_for_result(obj): break @@ -554,8 +614,9 @@ class ResultMetaData(object): if result is None: if raiseerr: raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" % - expression._string_or_unprintable(key)) + "Could not locate column in row for column '%s'" + % expression._string_or_unprintable(key) + ) else: return None else: @@ -580,34 +641,35 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) return operator.itemgetter(index) def __getstate__(self): return { - '_pickled_keymap': dict( + "_pickled_keymap": dict( (key, index) for key, (processor, obj, index) in self._keymap.items() if isinstance(key, util.string_types + util.int_types) ), - 'keys': self.keys, + "keys": self.keys, "case_sensitive": self.case_sensitive, - "matched_on_name": self.matched_on_name + "matched_on_name": self.matched_on_name, } def __setstate__(self, state): # the row has been processed at pickling time so we don't need any # processor anymore - self._processors = [None for _ in range(len(state['keys']))] + self._processors = [None for _ in range(len(state["keys"]))] self._keymap = keymap = {} - for key, index in state['_pickled_keymap'].items(): + for key, index in state["_pickled_keymap"].items(): # not preserving "obj" here, unfortunately our # proxy comparison fails with the unpickle keymap[key] = (None, None, index) - self.keys = state['keys'] - self.case_sensitive = state['case_sensitive'] - self.matched_on_name = state['matched_on_name'] + self.keys = state["keys"] + self.case_sensitive = state["case_sensitive"] + self.matched_on_name = state["matched_on_name"] class ResultProxy(object): @@ -643,8 +705,9 @@ class ResultProxy(object): self.dialect = context.dialect self.cursor = self._saved_cursor = context.cursor self.connection = context.root_connection - self._echo = self.connection._echo and \ - context.engine._should_log_debug() + self._echo = ( + self.connection._echo and context.engine._should_log_debug() + ) self._init_metadata() def _getter(self, key, raiseerr=True): @@ -666,18 +729,22 @@ class ResultProxy(object): def _init_metadata(self): cursor_description = self._cursor_description() if cursor_description is not None: - if self.context.compiled and \ - 'compiled_cache' in self.context.execution_options: + if ( + self.context.compiled + and "compiled_cache" in self.context.execution_options + ): if self.context.compiled._cached_metadata: self._metadata = self.context.compiled._cached_metadata else: - self._metadata = self.context.compiled._cached_metadata = \ - ResultMetaData(self, cursor_description) + self._metadata = ( + self.context.compiled._cached_metadata + ) = ResultMetaData(self, cursor_description) else: self._metadata = ResultMetaData(self, cursor_description) if self._echo: self.context.engine.logger.debug( - "Col %r", tuple(x[0] for x in cursor_description)) + "Col %r", tuple(x[0] for x in cursor_description) + ) def keys(self): """Return the current set of string keys for rows.""" @@ -731,7 +798,8 @@ class ResultProxy(object): return self.context.rowcount except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, self.cursor, self.context) + e, None, None, self.cursor, self.context + ) @property def lastrowid(self): @@ -753,8 +821,8 @@ class ResultProxy(object): return self._saved_cursor.lastrowid except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self._saved_cursor, self.context) + e, None, None, self._saved_cursor, self.context + ) @property def returns_rows(self): @@ -913,17 +981,18 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context._is_explicit_returning: raise exc.InvalidRequestError( "Can't call inserted_primary_key " "when returning() " - "is used.") + "is used." + ) return self.context.inserted_primary_key @@ -938,12 +1007,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isupdate: raise exc.InvalidRequestError( - "Statement is not an update() " - "expression construct.") + "Statement is not an update() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -960,12 +1029,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -1013,12 +1082,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.postfetch_cols def prefetch_cols(self): @@ -1035,12 +1105,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.prefetch_cols def supports_sane_rowcount(self): @@ -1086,7 +1157,7 @@ class ResultProxy(object): if self._metadata is None: raise exc.ResourceClosedError( "This result object does not return rows. " - "It has been closed automatically.", + "It has been closed automatically." ) elif self.closed: raise exc.ResourceClosedError("This result object is closed.") @@ -1106,8 +1177,9 @@ class ResultProxy(object): l.append(process_row(metadata, row, processors, keymap)) return l else: - return [process_row(metadata, row, processors, keymap) - for row in rows] + return [ + process_row(metadata, row, processors, keymap) for row in rows + ] def fetchall(self): """Fetch all rows, just like DB-API ``cursor.fetchall()``. @@ -1132,8 +1204,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchmany(self, size=None): """Fetch many rows, just like DB-API @@ -1161,8 +1233,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchone(self): """Fetch one row, just like DB-API ``cursor.fetchone()``. @@ -1190,8 +1262,8 @@ class ResultProxy(object): return None except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def first(self): """Fetch the first row and then close the result set unconditionally. @@ -1209,8 +1281,8 @@ class ResultProxy(object): row = self._fetchone_impl() except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) try: if row is not None: @@ -1268,7 +1340,8 @@ class BufferedRowResultProxy(ResultProxy): def _init_metadata(self): self._max_row_buffer = self.context.execution_options.get( - 'max_row_buffer', None) + "max_row_buffer", None + ) self.__buffer_rows() super(BufferedRowResultProxy, self)._init_metadata() @@ -1284,13 +1357,13 @@ class BufferedRowResultProxy(ResultProxy): 50: 100, 100: 250, 250: 500, - 500: 1000 + 500: 1000, } def __buffer_rows(self): if self.cursor is None: return - size = getattr(self, '_bufsize', 1) + size = getattr(self, "_bufsize", 1) self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) self._bufsize = self.size_growth.get(size, size) if self._max_row_buffer is not None: @@ -1385,8 +1458,9 @@ class BufferedColumnRow(RowProxy): row[index] = processor(row[index]) index += 1 row = tuple(row) - super(BufferedColumnRow, self).__init__(parent, row, - processors, keymap) + super(BufferedColumnRow, self).__init__( + parent, row, processors, keymap + ) class BufferedColumnResultProxy(ResultProxy): diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index d4f5185de..4aecb9537 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -51,18 +51,20 @@ class DefaultEngineStrategy(EngineStrategy): plugins = u._instantiate_plugins(kwargs) - u.query.pop('plugin', None) - kwargs.pop('plugins', None) + u.query.pop("plugin", None) + kwargs.pop("plugins", None) entrypoint = u._get_entrypoint() dialect_cls = entrypoint.get_dialect_cls(u) - if kwargs.pop('_coerce_config', False): + if kwargs.pop("_coerce_config", False): + def pop_kwarg(key, default=None): value = kwargs.pop(key, default) if key in dialect_cls.engine_config_types: value = dialect_cls.engine_config_types[key](value) return value + else: pop_kwarg = kwargs.pop @@ -72,7 +74,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: dialect_args[k] = pop_kwarg(k) - dbapi = kwargs.pop('module', None) + dbapi = kwargs.pop("module", None) if dbapi is None: dbapi_args = {} for k in util.get_func_kwargs(dialect_cls.dbapi): @@ -80,7 +82,7 @@ class DefaultEngineStrategy(EngineStrategy): dbapi_args[k] = pop_kwarg(k) dbapi = dialect_cls.dbapi(**dbapi_args) - dialect_args['dbapi'] = dbapi + dialect_args["dbapi"] = dbapi for plugin in plugins: plugin.handle_dialect_kwargs(dialect_cls, dialect_args) @@ -90,41 +92,43 @@ class DefaultEngineStrategy(EngineStrategy): # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) - cparams.update(pop_kwarg('connect_args', {})) + cparams.update(pop_kwarg("connect_args", {})) cargs = list(cargs) # allow mutability # look for existing pool or create - pool = pop_kwarg('pool', None) + pool = pop_kwarg("pool", None) if pool is None: + def connect(connection_record=None): if dialect._has_events: for fn in dialect.dispatch.do_connect: connection = fn( - dialect, connection_record, cargs, cparams) + dialect, connection_record, cargs, cparams + ) if connection is not None: return connection return dialect.connect(*cargs, **cparams) - creator = pop_kwarg('creator', connect) + creator = pop_kwarg("creator", connect) - poolclass = pop_kwarg('poolclass', None) + poolclass = pop_kwarg("poolclass", None) if poolclass is None: poolclass = dialect_cls.get_pool_class(u) - pool_args = { - 'dialect': dialect - } + pool_args = {"dialect": dialect} # consume pool arguments from kwargs, translating a few of # the arguments - translate = {'logging_name': 'pool_logging_name', - 'echo': 'echo_pool', - 'timeout': 'pool_timeout', - 'recycle': 'pool_recycle', - 'events': 'pool_events', - 'use_threadlocal': 'pool_threadlocal', - 'reset_on_return': 'pool_reset_on_return', - 'pre_ping': 'pool_pre_ping', - 'use_lifo': 'pool_use_lifo'} + translate = { + "logging_name": "pool_logging_name", + "echo": "echo_pool", + "timeout": "pool_timeout", + "recycle": "pool_recycle", + "events": "pool_events", + "use_threadlocal": "pool_threadlocal", + "reset_on_return": "pool_reset_on_return", + "pre_ping": "pool_pre_ping", + "use_lifo": "pool_use_lifo", + } for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: @@ -149,7 +153,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: engine_args[k] = pop_kwarg(k) - _initialize = kwargs.pop('_initialize', True) + _initialize = kwargs.pop("_initialize", True) # all kwargs should be consumed if kwargs: @@ -157,32 +161,40 @@ class DefaultEngineStrategy(EngineStrategy): "Invalid argument(s) %s sent to create_engine(), " "using configuration %s/%s/%s. Please check that the " "keyword arguments are appropriate for this combination " - "of components." % (','.join("'%s'" % k for k in kwargs), - dialect.__class__.__name__, - pool.__class__.__name__, - engineclass.__name__)) + "of components." + % ( + ",".join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__, + ) + ) engine = engineclass(pool, dialect, u, **engine_args) if _initialize: do_on_connect = dialect.on_connect() if do_on_connect: + def on_connect(dbapi_connection, connection_record): conn = getattr( - dbapi_connection, '_sqla_unwrap', dbapi_connection) + dbapi_connection, "_sqla_unwrap", dbapi_connection + ) if conn is None: return do_on_connect(conn) - event.listen(pool, 'first_connect', on_connect) - event.listen(pool, 'connect', on_connect) + event.listen(pool, "first_connect", on_connect) + event.listen(pool, "connect", on_connect) def first_connect(dbapi_connection, connection_record): - c = base.Connection(engine, connection=dbapi_connection, - _has_events=False) + c = base.Connection( + engine, connection=dbapi_connection, _has_events=False + ) c._execution_options = util.immutabledict() dialect.initialize(c) - event.listen(pool, 'first_connect', first_connect, once=True) + + event.listen(pool, "first_connect", first_connect, once=True) dialect_cls.engine_created(engine) if entrypoint is not dialect_cls: @@ -197,18 +209,20 @@ class DefaultEngineStrategy(EngineStrategy): class PlainEngineStrategy(DefaultEngineStrategy): """Strategy for configuring a regular Engine.""" - name = 'plain' + name = "plain" engine_cls = base.Engine + PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): """Strategy for configuring an Engine with threadlocal behavior.""" - name = 'threadlocal' + name = "threadlocal" engine_cls = threadlocal.TLEngine + ThreadLocalEngineStrategy() @@ -220,7 +234,7 @@ class MockEngineStrategy(EngineStrategy): """ - name = 'mock' + name = "mock" def create(self, name_or_url, executor, **kwargs): # create url.URL object @@ -245,7 +259,7 @@ class MockEngineStrategy(EngineStrategy): self.execute = execute engine = property(lambda s: s) - dialect = property(attrgetter('_dialect')) + dialect = property(attrgetter("_dialect")) name = property(lambda s: s._dialect.name) schema_for_object = schema._schema_getter(None) @@ -258,29 +272,35 @@ class MockEngineStrategy(EngineStrategy): def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( - statement, parameters, engine=self, **kwargs) + statement, parameters, engine=self, **kwargs + ) def create(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaGenerator( - self.dialect, self, **kwargs).traverse_single(entity) + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single( + entity + ) def drop(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaDropper( - self.dialect, self, **kwargs).traverse_single(entity) - def _run_visitor(self, visitorcallable, element, - connection=None, - **kwargs): - kwargs['checkfirst'] = False - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single( + entity + ) + + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): + kwargs["checkfirst"] = False + visitorcallable(self.dialect, self, **kwargs).traverse_single( + element + ) def execute(self, object, *multiparams, **params): raise NotImplementedError() + MockEngineStrategy() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 0ec1f9613..5b2bdabc0 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -19,7 +19,6 @@ import weakref class TLConnection(base.Connection): - def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 @@ -43,6 +42,7 @@ class TLEngine(base.Engine): transactions. """ + _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): @@ -50,7 +50,7 @@ class TLEngine(base.Engine): self._connections = util.threading.local() def contextual_connect(self, **kw): - if not hasattr(self._connections, 'conn'): + if not hasattr(self._connections, "conn"): connection = None else: connection = self._connections.conn() @@ -60,29 +60,31 @@ class TLEngine(base.Engine): # or not connection.connection.is_valid: connection = self._tl_connection_cls( self, - self._wrap_pool_connect( - self.pool.connect, connection), - **kw) + self._wrap_pool_connect(self.pool.connect, connection), + **kw + ) self._connections.conn = weakref.ref(connection) return connection._increment_connect() def begin_twophase(self, xid=None): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_twophase(xid=xid)) + self.contextual_connect().begin_twophase(xid=xid) + ) return self def begin_nested(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_nested()) + self.contextual_connect().begin_nested() + ) return self def begin(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append(self.contextual_connect().begin()) return self @@ -97,21 +99,27 @@ class TLEngine(base.Engine): self.rollback() def prepare(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return self._connections.trans[-1].prepare() def commit(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.commit() def rollback(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.rollback() @@ -122,9 +130,11 @@ class TLEngine(base.Engine): @property def closed(self): - return not hasattr(self._connections, 'conn') or \ - self._connections.conn() is None or \ - self._connections.conn().closed + return ( + not hasattr(self._connections, "conn") + or self._connections.conn() is None + or self._connections.conn().closed + ) def close(self): if not self.closed: @@ -135,4 +145,4 @@ class TLEngine(base.Engine): self._connections.trans = [] def __repr__(self): - return 'TLEngine(%r)' % self.url + return "TLEngine(%r)" % self.url diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 1662efe20..e92e57b8e 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -50,8 +50,16 @@ class URL(object): """ - def __init__(self, drivername, username=None, password=None, - host=None, port=None, database=None, query=None): + def __init__( + self, + drivername, + username=None, + password=None, + host=None, + port=None, + database=None, + query=None, + ): self.drivername = drivername self.username = username self.password_original = password @@ -68,26 +76,26 @@ class URL(object): if self.username is not None: s += _rfc_1738_quote(self.username) if self.password is not None: - s += ':' + ('***' if hide_password - else _rfc_1738_quote(self.password)) + s += ":" + ( + "***" if hide_password else _rfc_1738_quote(self.password) + ) s += "@" if self.host is not None: - if ':' in self.host: + if ":" in self.host: s += "[%s]" % self.host else: s += self.host if self.port is not None: - s += ':' + str(self.port) + s += ":" + str(self.port) if self.database is not None: - s += '/' + self.database + s += "/" + self.database if self.query: keys = list(self.query) keys.sort() - s += '?' + "&".join( - "%s=%s" % ( - k, - element - ) for k in keys for element in util.to_list(self.query[k]) + s += "?" + "&".join( + "%s=%s" % (k, element) + for k in keys + for element in util.to_list(self.query[k]) ) return s @@ -101,14 +109,15 @@ class URL(object): return hash(str(self)) def __eq__(self, other): - return \ - isinstance(other, URL) and \ - self.drivername == other.drivername and \ - self.username == other.username and \ - self.password == other.password and \ - self.host == other.host and \ - self.database == other.database and \ - self.query == other.query + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + ) @property def password(self): @@ -122,20 +131,20 @@ class URL(object): self.password_original = password def get_backend_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.drivername else: - return self.drivername.split('+')[0] + return self.drivername.split("+")[0] def get_driver_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.get_dialect().driver else: - return self.drivername.split('+')[1] + return self.drivername.split("+")[1] def _instantiate_plugins(self, kwargs): - plugin_names = util.to_list(self.query.get('plugin', ())) - plugin_names += kwargs.get('plugins', []) + plugin_names = util.to_list(self.query.get("plugin", ())) + plugin_names += kwargs.get("plugins", []) return [ plugins.load(plugin_name)(self, kwargs) @@ -149,17 +158,19 @@ class URL(object): returned class implements the get_dialect_cls() method. """ - if '+' not in self.drivername: + if "+" not in self.drivername: name = self.drivername else: - name = self.drivername.replace('+', '.') + name = self.drivername.replace("+", ".") cls = registry.load(name) # check for legacy dialects that # would return a module with 'dialect' as the # actual class - if hasattr(cls, 'dialect') and \ - isinstance(cls.dialect, type) and \ - issubclass(cls.dialect, Dialect): + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): return cls.dialect else: return cls @@ -187,7 +198,7 @@ class URL(object): """ translated = {} - attribute_names = ['host', 'database', 'username', 'password', 'port'] + attribute_names = ["host", "database", "username", "password", "port"] for sname in attribute_names: if names: name = names.pop(0) @@ -214,7 +225,8 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): - pattern = re.compile(r''' + pattern = re.compile( + r""" (?P<name>[\w\+]+):// (?: (?P<username>[^:/]*) @@ -228,21 +240,23 @@ def _parse_rfc1738_args(name): (?::(?P<port>[^/]*))? )? (?:/(?P<database>.*))? - ''', re.X) + """, + re.X, + ) m = pattern.match(name) if m is not None: components = m.groupdict() - if components['database'] is not None: - tokens = components['database'].split('?', 2) - components['database'] = tokens[0] + if components["database"] is not None: + tokens = components["database"].split("?", 2) + components["database"] = tokens[0] if len(tokens) > 1: query = {} for key, value in util.parse_qsl(tokens[1]): if util.py2k: - key = key.encode('ascii') + key = key.encode("ascii") if key in query: query[key] = util.to_list(query[key]) query[key].append(value) @@ -252,26 +266,27 @@ def _parse_rfc1738_args(name): query = None else: query = None - components['query'] = query + components["query"] = query - if components['username'] is not None: - components['username'] = _rfc_1738_unquote(components['username']) + if components["username"] is not None: + components["username"] = _rfc_1738_unquote(components["username"]) - if components['password'] is not None: - components['password'] = _rfc_1738_unquote(components['password']) + if components["password"] is not None: + components["password"] = _rfc_1738_unquote(components["password"]) - ipv4host = components.pop('ipv4host') - ipv6host = components.pop('ipv6host') - components['host'] = ipv4host or ipv6host - name = components.pop('name') + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") return URL(name, **components) else: raise exc.ArgumentError( - "Could not parse rfc1738 URL from string '%s'" % name) + "Could not parse rfc1738 URL from string '%s'" % name + ) def _rfc_1738_quote(text): - return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) def _rfc_1738_unquote(text): @@ -279,7 +294,7 @@ def _rfc_1738_unquote(text): def _parse_keyvalue_args(name): - m = re.match(r'(\w+)://(.*)', name) + m = re.match(r"(\w+)://(.*)", name) if m is not None: (name, args) = m.group(1, 2) opts = dict(util.parse_qsl(args)) diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 17bc9a3b4..76bb8f4b5 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -46,28 +46,34 @@ def py_fallback(): elif len(multiparams) == 1: zero = multiparams[0] if isinstance(zero, (list, tuple)): - if not zero or hasattr(zero[0], '__iter__') and \ - not hasattr(zero[0], 'strip'): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): # execute(stmt, [{}, {}, {}, ...]) # execute(stmt, [(), (), (), ...]) return zero else: # execute(stmt, ("value", "value")) return [zero] - elif hasattr(zero, 'keys'): + elif hasattr(zero, "keys"): # execute(stmt, {"key":"value"}) return [zero] else: # execute(stmt, "value") return [[zero]] else: - if hasattr(multiparams[0], '__iter__') and \ - not hasattr(multiparams[0], 'strip'): + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): return multiparams else: return [multiparams] return locals() + + try: from sqlalchemy.cutils import _distill_params except ImportError: |
